From 75f8f7e689debc1c6cf41d71d0c6115dc3ac3470 Mon Sep 17 00:00:00 2001 From: narzoul Date: Wed, 24 Feb 2021 20:35:15 +0100 Subject: [PATCH] Replace vtable function pointers instead of using Detours --- DDrawCompat/Common/CompatVtable.h | 5 +- DDrawCompat/Common/CompatVtableInstance.h | 22 +--- DDrawCompat/Common/FuncNameVisitor.h | 51 ---------- DDrawCompat/Common/Hook.cpp | 26 ++--- DDrawCompat/Common/Hook.h | 3 + DDrawCompat/Common/Log.h | 2 +- DDrawCompat/Common/LogWrapperVisitor.h | 69 ------------- DDrawCompat/Common/VtableHookVisitor.h | 117 +++++++++++++--------- DDrawCompat/Common/VtableUpdateVisitor.h | 27 ----- DDrawCompat/D3dDdi/D3dDdiVtable.h | 60 ++++------- DDrawCompat/D3dDdi/Hooks.cpp | 3 +- DDrawCompat/DDrawCompat.vcxproj | 3 - DDrawCompat/DDrawCompat.vcxproj.filters | 9 -- DDrawCompat/Direct3d/Direct3dDevice.cpp | 1 + DDrawCompat/Gdi/User32WndProcs.cpp | 1 + 15 files changed, 119 insertions(+), 280 deletions(-) delete mode 100644 DDrawCompat/Common/FuncNameVisitor.h delete mode 100644 DDrawCompat/Common/LogWrapperVisitor.h delete mode 100644 DDrawCompat/Common/VtableUpdateVisitor.h diff --git a/DDrawCompat/Common/CompatVtable.h b/DDrawCompat/Common/CompatVtable.h index 1556d62..78890a0 100644 --- a/DDrawCompat/Common/CompatVtable.h +++ b/DDrawCompat/Common/CompatVtable.h @@ -1,12 +1,13 @@ #pragma once #include +#include template using Vtable = typename std::remove_pointer::type; template -class CompatVtable : public CompatVtableInstance +class CompatVtable : public CompatVtableInstance { public: static const Vtable& getOrigVtable(const Vtable& vtable) @@ -21,7 +22,7 @@ public: s_origVtablePtr = vtable; Vtable compatVtable = {}; Compat::setCompatVtable(compatVtable); - CompatVtableInstance::hookVtable(*vtable, compatVtable); + CompatVtableInstance::hookVtable(*vtable, compatVtable); } } diff --git a/DDrawCompat/Common/CompatVtableInstance.h b/DDrawCompat/Common/CompatVtableInstance.h index fff181e..0ad776f 100644 --- a/DDrawCompat/Common/CompatVtableInstance.h +++ b/DDrawCompat/Common/CompatVtableInstance.h @@ -1,8 +1,6 @@ #pragma once -#include #include -#include #include #define SET_COMPAT_VTABLE(Vtable, CompatInterface) \ @@ -21,26 +19,14 @@ public: static Vtable* s_origVtablePtr; }; -template +template class CompatVtableInstance : public CompatVtableInstanceBase { public: static void hookVtable(const Vtable& origVtable, Vtable compatVtable) { -#ifdef DEBUGLOGS - LogWrapperVisitor logWrapperVisitor(origVtable, compatVtable); - forEach(logWrapperVisitor); -#endif - - VtableHookVisitor vtableHookVisitor(origVtable, s_origVtable, compatVtable); + VtableHookVisitor vtableHookVisitor(origVtable, s_origVtable, compatVtable); forEach(vtableHookVisitor); - -#ifdef DEBUGLOGS - VtableUpdateVisitor vtableUpdateVisitor( - origVtable, s_origVtable, LogWrapperVisitor::s_compatVtable); - forEach(vtableUpdateVisitor); -#endif - s_origVtablePtr = &s_origVtable; } @@ -50,5 +36,5 @@ public: template Vtable* CompatVtableInstanceBase::s_origVtablePtr = nullptr; -template -Vtable CompatVtableInstance::s_origVtable = {}; +template +Vtable CompatVtableInstance::s_origVtable = {}; diff --git a/DDrawCompat/Common/FuncNameVisitor.h b/DDrawCompat/Common/FuncNameVisitor.h deleted file mode 100644 index a816c52..0000000 --- a/DDrawCompat/Common/FuncNameVisitor.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -template -class FuncNameVisitor -{ -public: - template - static void visit(const char* funcName) - { - s_funcNames[getKey()] = s_vtableTypeName + "::" + funcName; - } - - template - static const char* getFuncName() - { - return s_funcNames[getKey()].c_str(); - } - -private: - template - static std::vector getKey() - { - MemberDataPtr mp = ptr; - unsigned char* p = reinterpret_cast(&mp); - return std::vector(p, p + sizeof(mp)); - } - - static std::string getVtableTypeName() - { - std::string name = typeid(Vtable).name(); - if (0 == name.find("struct ")) - { - name = name.substr(name.find(" ") + 1); - } - return name; - } - - static std::string s_vtableTypeName; - static std::map, std::string> s_funcNames; -}; - -template -std::string FuncNameVisitor::s_vtableTypeName(getVtableTypeName()); - -template -std::map, std::string> FuncNameVisitor::s_funcNames; diff --git a/DDrawCompat/Common/Hook.cpp b/DDrawCompat/Common/Hook.cpp index ded6b1d..356c50a 100644 --- a/DDrawCompat/Common/Hook.cpp +++ b/DDrawCompat/Common/Hook.cpp @@ -80,15 +80,6 @@ namespace return nullptr; } - std::string funcAddrToStr(void* funcPtr) - { - std::ostringstream oss; - HMODULE module = Compat::getModuleHandleFromAddress(funcPtr); - oss << getModulePath(module).string() << "+0x" << std::hex << - reinterpret_cast(funcPtr) - reinterpret_cast(module); - return oss.str(); - } - PIMAGE_NT_HEADERS getImageNtHeaders(HMODULE module) { PIMAGE_DOS_HEADER dosHeader = reinterpret_cast(module); @@ -149,12 +140,12 @@ namespace void* const hookedFuncPtr = origFuncPtr; if (stubFuncPtr) { - LOG_DEBUG << "Hooking function: " << funcName << " (" << funcAddrToStr(stubFuncPtr) << " -> " - << funcAddrToStr(origFuncPtr) << ')'; + LOG_DEBUG << "Hooking function: " << funcName << " (" << Compat::funcPtrToStr(stubFuncPtr) << " -> " + << Compat::funcPtrToStr(origFuncPtr) << ')'; } else { - LOG_DEBUG << "Hooking function: " << funcName << " (" << funcAddrToStr(hookedFuncPtr) << ')'; + LOG_DEBUG << "Hooking function: " << funcName << " (" << Compat::funcPtrToStr(hookedFuncPtr) << ')'; } DetourTransactionBegin(); @@ -189,6 +180,15 @@ namespace namespace Compat { + std::string funcPtrToStr(void* funcPtr) + { + std::ostringstream oss; + HMODULE module = Compat::getModuleHandleFromAddress(funcPtr); + oss << getModulePath(module).string() << "+0x" << std::hex << + reinterpret_cast(funcPtr) - reinterpret_cast(module); + return oss.str(); + } + HMODULE getModuleHandleFromAddress(void* address) { HMODULE module = nullptr; @@ -304,7 +304,7 @@ namespace Compat FARPROC* func = findProcAddressInIat(module, importedModuleName, funcName); if (func) { - LOG_DEBUG << "Hooking function via IAT: " << funcName << " (" << funcAddrToStr(*func) << ')'; + LOG_DEBUG << "Hooking function via IAT: " << funcName << " (" << funcPtrToStr(*func) << ')'; DWORD oldProtect = 0; VirtualProtect(func, sizeof(func), PAGE_READWRITE, &oldProtect); *func = static_cast(newFuncPtr); diff --git a/DDrawCompat/Common/Hook.h b/DDrawCompat/Common/Hook.h index 06b8839..b61bdaf 100644 --- a/DDrawCompat/Common/Hook.h +++ b/DDrawCompat/Common/Hook.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #define CALL_ORIG_FUNC(func) Compat::getOrigFuncPtr() @@ -13,6 +15,7 @@ namespace Compat { + std::string funcPtrToStr(void* funcPtr); HMODULE getModuleHandleFromAddress(void* address); template diff --git a/DDrawCompat/Common/Log.h b/DDrawCompat/Common/Log.h index f0572ec..10bf2ab 100644 --- a/DDrawCompat/Common/Log.h +++ b/DDrawCompat/Common/Log.h @@ -18,7 +18,7 @@ #define LOG_FUNC(...) Compat::LogFunc logFunc(__VA_ARGS__) #define LOG_RESULT(...) logFunc.setResult(__VA_ARGS__) #else -#define LOG_DEBUG if (false) Compat::Log() +#define LOG_DEBUG if constexpr (false) Compat::Log() #define LOG_FUNC(...) #define LOG_RESULT(...) __VA_ARGS__ #endif diff --git a/DDrawCompat/Common/LogWrapperVisitor.h b/DDrawCompat/Common/LogWrapperVisitor.h deleted file mode 100644 index 8e832af..0000000 --- a/DDrawCompat/Common/LogWrapperVisitor.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -#include -#include - -template -class LogWrapperVisitor -{ -public: - LogWrapperVisitor(const Vtable& origVtable, Vtable& compatVtable) - : m_origVtable(origVtable) - , m_compatVtable(compatVtable) - { - } - - template - void visit(const char* funcName) - { - FuncNameVisitor::visit(funcName); - - if (!(m_compatVtable.*ptr)) - { - m_compatVtable.*ptr = m_origVtable.*ptr; - } - - s_compatVtable.*ptr = m_compatVtable.*ptr; - m_compatVtable.*ptr = getLoggedFuncPtr(m_compatVtable.*ptr); - } - - static Vtable s_compatVtable; - -private: - template - using FuncPtr = Result(STDMETHODCALLTYPE *)(Params...); - - template - static FuncPtr getLoggedFuncPtr(FuncPtr) - { - return &loggedFunc; - } - - template - static FuncPtr getLoggedFuncPtr(FuncPtr) - { - return &loggedFunc; - } - - template - static Result STDMETHODCALLTYPE loggedFunc(Params... params) - { - const char* funcName = FuncNameVisitor::getFuncName(); - LOG_FUNC(funcName, params...); - return LOG_RESULT((s_compatVtable.*ptr)(params...)); - } - - template - static void STDMETHODCALLTYPE loggedFunc(Params... params) - { - const char* funcName = FuncNameVisitor::getFuncName(); - LOG_FUNC(funcName, params...); - (s_compatVtable.*ptr)(params...); - } - - const Vtable& m_origVtable; - Vtable& m_compatVtable; -}; - -template -Vtable LogWrapperVisitor::s_compatVtable; diff --git a/DDrawCompat/Common/VtableHookVisitor.h b/DDrawCompat/Common/VtableHookVisitor.h index 778608b..52ea40f 100644 --- a/DDrawCompat/Common/VtableHookVisitor.h +++ b/DDrawCompat/Common/VtableHookVisitor.h @@ -1,76 +1,103 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include -struct _D3DDDI_ADAPTERCALLBACKS; -struct _D3DDDI_ADAPTERFUNCS; -struct _D3DDDI_DEVICECALLBACKS; -struct _D3DDDI_DEVICEFUNCS; +#include +#include template -class ScopedVtableFuncLock : public DDraw::ScopedThreadLock {}; +class VtableHookVisitorBase +{ +protected: + template + static std::string& getFuncName() + { + static std::string funcName; + return funcName; + } -template <> -class ScopedVtableFuncLock<_D3DDDI_ADAPTERCALLBACKS> : public D3dDdi::ScopedCriticalSection {}; + static std::string getVtableTypeName() + { + std::string name = typeid(Vtable).name(); + if (0 == name.find("struct ")) + { + name = name.substr(name.find(" ") + 1); + } + return name; + } -template <> -class ScopedVtableFuncLock<_D3DDDI_ADAPTERFUNCS> : public D3dDdi::ScopedCriticalSection {}; + static std::string s_vtableTypeName; +}; -template <> -class ScopedVtableFuncLock<_D3DDDI_DEVICECALLBACKS> : public D3dDdi::ScopedCriticalSection {}; - -template <> -class ScopedVtableFuncLock<_D3DDDI_DEVICEFUNCS> : public D3dDdi::ScopedCriticalSection {}; - -template -class VtableHookVisitor +template +class VtableHookVisitor : public VtableHookVisitorBase { public: - VtableHookVisitor(const Vtable& srcVtable, Vtable& origVtable, const Vtable& compatVtable) - : m_srcVtable(srcVtable) + VtableHookVisitor(const Vtable& hookedVtable, Vtable& origVtable, const Vtable& compatVtable) + : m_hookedVtable(const_cast(hookedVtable)) , m_origVtable(origVtable) { s_compatVtable = compatVtable; } template - void visit(const char* /*funcName*/) + void visit([[maybe_unused]] const char* funcName) { - m_origVtable.*ptr = m_srcVtable.*ptr; - if (m_origVtable.*ptr && s_compatVtable.*ptr) +#ifdef DEBUGLOGS + getFuncName() = s_vtableTypeName + "::" + funcName; + if (!(s_compatVtable.*ptr)) { - Compat::hookFunction(reinterpret_cast(m_origVtable.*ptr), - getThreadSafeFuncPtr(s_compatVtable.*ptr), - FuncNameVisitor::getFuncName()); + s_compatVtable.*ptr = m_hookedVtable.*ptr; + } +#endif + + m_origVtable.*ptr = m_hookedVtable.*ptr; + if (m_hookedVtable.*ptr && s_compatVtable.*ptr) + { + LOG_DEBUG << "Hooking function: " << getFuncName() + << " (" << Compat::funcPtrToStr(m_hookedVtable.*ptr) << ')'; + DWORD oldProtect = 0; + VirtualProtect(&(m_hookedVtable.*ptr), sizeof(m_hookedVtable.*ptr), PAGE_READWRITE, &oldProtect); + m_hookedVtable.*ptr = &hookFunc; + VirtualProtect(&(m_hookedVtable.*ptr), sizeof(m_hookedVtable.*ptr), oldProtect, &oldProtect); } } private: - template - using FuncPtr = Result(STDMETHODCALLTYPE *)(Params...); - template - static FuncPtr getThreadSafeFuncPtr(FuncPtr) + static Result STDMETHODCALLTYPE hookFunc(Params... params) { - return &threadSafeFunc; +#ifdef DEBUGLOGS + const char* funcName = getFuncName().c_str(); +#endif + LOG_FUNC(funcName, params...); + Lock lock; + if constexpr (-1 != instanceId) + { + CompatVtableInstanceBase::s_origVtablePtr = &CompatVtableInstance::s_origVtable; + } + if constexpr (std::is_same_v) + { + (s_compatVtable.*ptr)(params...); + } + else + { + return LOG_RESULT((s_compatVtable.*ptr)(params...)); + } } - template - static Result STDMETHODCALLTYPE threadSafeFunc(Params... params) - { - ScopedVtableFuncLock lock; - CompatVtableInstanceBase::s_origVtablePtr = &CompatVtableInstance::s_origVtable; - return (s_compatVtable.*ptr)(params...); - } - - const Vtable& m_srcVtable; + Vtable& m_hookedVtable; Vtable& m_origVtable; static Vtable s_compatVtable; }; -template -Vtable VtableHookVisitor::s_compatVtable = {}; +#ifdef DEBUGLOGS +template +std::string VtableHookVisitorBase::s_vtableTypeName(getVtableTypeName()); +#endif + +template +Vtable VtableHookVisitor::s_compatVtable = {}; diff --git a/DDrawCompat/Common/VtableUpdateVisitor.h b/DDrawCompat/Common/VtableUpdateVisitor.h deleted file mode 100644 index 3e44447..0000000 --- a/DDrawCompat/Common/VtableUpdateVisitor.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -template -class VtableUpdateVisitor -{ -public: - VtableUpdateVisitor(const Vtable& preHookOrigVtable, const Vtable& postHookOrigVtable, Vtable& vtable) - : m_preHookOrigVtable(preHookOrigVtable) - , m_postHookOrigVtable(postHookOrigVtable) - , m_vtable(vtable) - { - } - - template - void visit(const char* /*funcName*/) - { - if (m_preHookOrigVtable.*ptr == m_vtable.*ptr) - { - m_vtable.*ptr = m_postHookOrigVtable.*ptr; - } - } - -private: - const Vtable& m_preHookOrigVtable; - const Vtable& m_postHookOrigVtable; - Vtable& m_vtable; -}; diff --git a/DDrawCompat/D3dDdi/D3dDdiVtable.h b/DDrawCompat/D3dDdi/D3dDdiVtable.h index 76df51c..39a718c 100644 --- a/DDrawCompat/D3dDdi/D3dDdiVtable.h +++ b/DDrawCompat/D3dDdi/D3dDdiVtable.h @@ -1,7 +1,5 @@ #pragma once -#include - #include #include #include @@ -16,76 +14,59 @@ namespace D3dDdi public: static void hookVtable(HMODULE module, const Vtable* vtable) { - if (!vtable) + if (vtable) { - return; - } - - auto it = s_origModuleVtables.find(module); - if (s_origModuleVtables.end() == it) - { - it = s_origModuleVtables.emplace(module, hookVtableInstance(*vtable, InstanceId<0>())).first; + hookVtableInstance(module, *vtable, InstanceId<0>()); } } - static std::map s_origModuleVtables; static Vtable*& s_origVtablePtr; private: - template - class Visitor + class CopyVisitor { public: - Visitor(Vtable& compatVtable) + CopyVisitor(Vtable& compatVtable, const Vtable& origVtable) : m_compatVtable(compatVtable) + , m_origVtable(origVtable) { } template void visit(const char* /*funcName*/) { - if (!(m_compatVtable.*ptr)) - { - m_compatVtable.*ptr = &threadSafeFunc; - } + m_compatVtable.*ptr = m_origVtable.*ptr; } private: - template - static Result APIENTRY threadSafeFunc(Params... params) - { - D3dDdi::ScopedCriticalSection lock; - return (CompatVtableInstance::s_origVtable.*ptr)(params...); - } - Vtable& m_compatVtable; + const Vtable& m_origVtable; }; template struct InstanceId {}; template - static const Vtable& hookVtableInstance(const Vtable& vtable, InstanceId) + static const Vtable& hookVtableInstance(HMODULE module, const Vtable& vtable, InstanceId) { - static bool isHooked = false; - if (isHooked) + static HMODULE hookedModule = nullptr; + if (hookedModule && hookedModule != module) { - return hookVtableInstance(vtable, InstanceId()); + return hookVtableInstance(module, vtable, InstanceId()); } + hookedModule = module; Vtable compatVtable = {}; + CopyVisitor copyVisitor(compatVtable, vtable); + forEach(copyVisitor); + Compat::setCompatVtable(compatVtable); -#ifndef DEBUGLOGS - Visitor visitor(compatVtable); - forEach(visitor); -#endif - - isHooked = true; - CompatVtableInstance::hookVtable(vtable, compatVtable); - return CompatVtableInstance::s_origVtable; + CompatVtableInstance::hookVtable(vtable, compatVtable); + return CompatVtableInstance::s_origVtable; } - static const Vtable& hookVtableInstance(const Vtable& /*vtable*/, InstanceId) + static const Vtable& hookVtableInstance(HMODULE /*module*/, const Vtable& /*vtable*/, + InstanceId) { Compat::Log() << "ERROR: Cannot hook more than " << Config::maxUserModeDisplayDrivers << " user-mode display drivers. Recompile with Config::maxUserModeDisplayDrivers > " << @@ -95,9 +76,6 @@ namespace D3dDdi } }; - template - std::map D3dDdiVtable::s_origModuleVtables; - template Vtable*& D3dDdiVtable::s_origVtablePtr = CompatVtableInstanceBase::s_origVtablePtr; } diff --git a/DDrawCompat/D3dDdi/Hooks.cpp b/DDrawCompat/D3dDdi/Hooks.cpp index ff6e7d5..366252e 100644 --- a/DDrawCompat/D3dDdi/Hooks.cpp +++ b/DDrawCompat/D3dDdi/Hooks.cpp @@ -41,7 +41,8 @@ namespace { Compat::hookFunction(g_hookedUmdModule, "OpenAdapter", reinterpret_cast(g_origOpenAdapter), &openAdapter); - FreeLibrary(g_hookedUmdModule); + HMODULE module = nullptr; + GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_PIN, umdFileName.c_str(), &module); } } diff --git a/DDrawCompat/DDrawCompat.vcxproj b/DDrawCompat/DDrawCompat.vcxproj index dd3ec3a..669a976 100644 --- a/DDrawCompat/DDrawCompat.vcxproj +++ b/DDrawCompat/DDrawCompat.vcxproj @@ -136,13 +136,10 @@ - - - diff --git a/DDrawCompat/DDrawCompat.vcxproj.filters b/DDrawCompat/DDrawCompat.vcxproj.filters index 8682499..08c7946 100644 --- a/DDrawCompat/DDrawCompat.vcxproj.filters +++ b/DDrawCompat/DDrawCompat.vcxproj.filters @@ -303,21 +303,12 @@ Header Files\D3dDdi - - Header Files\Common - Header Files\Common Header Files\Common - - Header Files\Common - - - Header Files\Common - Header Files\D3dDdi diff --git a/DDrawCompat/Direct3d/Direct3dDevice.cpp b/DDrawCompat/Direct3d/Direct3dDevice.cpp index 915febe..cf3659b 100644 --- a/DDrawCompat/Direct3d/Direct3dDevice.cpp +++ b/DDrawCompat/Direct3d/Direct3dDevice.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include diff --git a/DDrawCompat/Gdi/User32WndProcs.cpp b/DDrawCompat/Gdi/User32WndProcs.cpp index c827ef1..e9a473e 100644 --- a/DDrawCompat/Gdi/User32WndProcs.cpp +++ b/DDrawCompat/Gdi/User32WndProcs.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include