diff --git a/DDrawCompat/Common/Hook.cpp b/DDrawCompat/Common/Hook.cpp index 550ae1f..99fa430 100644 --- a/DDrawCompat/Common/Hook.cpp +++ b/DDrawCompat/Common/Hook.cpp @@ -84,6 +84,14 @@ namespace return ntHeaders; } + HMODULE getModuleHandleFromAddress(void* address) + { + HMODULE module = nullptr; + GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(address), &module); + return module; + } + std::filesystem::path getModulePath(HMODULE module) { char path[MAX_PATH] = {}; @@ -94,9 +102,7 @@ namespace std::string funcAddrToStr(void* funcPtr) { std::ostringstream oss; - HMODULE module = nullptr; - GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, - static_cast(funcPtr), &module); + HMODULE module = getModuleHandleFromAddress(funcPtr); oss << getModulePath(module).string() << "+0x" << std::hex << reinterpret_cast(funcPtr) - reinterpret_cast(module); return oss.str(); @@ -160,10 +166,8 @@ namespace Compat for (auto hookFunc : hookFunctions) { - HMODULE module = nullptr; - if (!GetModuleHandleEx( - GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, - static_cast(hookFunc), &module)) + HMODULE module = getModuleHandleFromAddress(hookFunc); + if (!module) { continue; } @@ -297,6 +301,18 @@ namespace Compat } } + // Avoid hooking ntdll stubs (e.g. ntdll/NtdllDialogWndProc_A instead of user32/DefDlgProcA) + if (func && getModuleHandleFromAddress(func) != module && + 0xFF == static_cast(func[0]) && + 0x25 == static_cast(func[1])) + { + FARPROC jmpTarget = **reinterpret_cast(func + 2); + if (getModuleHandleFromAddress(jmpTarget) == module) + { + return jmpTarget; + } + } + return reinterpret_cast(func); }