From 374636252895fad9da1afbd496290e70aa03b160 Mon Sep 17 00:00:00 2001 From: narzoul Date: Mon, 19 Sep 2016 00:12:45 +0200 Subject: [PATCH] Allow hooking multiple user mode display drivers --- DDrawCompat/Common/CompatVtable.h | 114 +++++++++++++--------- DDrawCompat/Common/Hook.cpp | 60 ++++++++++-- DDrawCompat/Common/Hook.h | 6 ++ DDrawCompat/D3dDdi/AdapterFuncs.cpp | 15 ++- DDrawCompat/D3dDdi/DeviceFuncs.cpp | 18 +++- DDrawCompat/D3dDdi/Hooks.cpp | 144 ++++++++-------------------- 6 files changed, 195 insertions(+), 162 deletions(-) diff --git a/DDrawCompat/Common/CompatVtable.h b/DDrawCompat/Common/CompatVtable.h index ed45926..79363da 100644 --- a/DDrawCompat/Common/CompatVtable.h +++ b/DDrawCompat/Common/CompatVtable.h @@ -36,34 +36,63 @@ public: { s_origVtablePtr = vtable; - InitVisitor visitor(*vtable); + HookVisitor visitor(*vtable, s_origVtable); + forEach(visitor); + } + } + + static void hookDriverVtable(HANDLE context, const Vtable* vtable) + { + if (vtable && s_origVtables.find(context) == s_origVtables.end()) + { + HookVisitor visitor(*vtable, s_origVtables[context]); forEach(visitor); } } static Vtable s_origVtable; + static std::map s_origVtables; static const Vtable* s_origVtablePtr; private: - class InitVisitor + class DDrawHook { public: - InitVisitor(const Vtable& origVtable) : m_origVtable(origVtable) {} + template + static decltype(s_compatVtable.*ptr) getCompatFunc(FirstParam) + { + return s_compatVtable.*ptr ? s_compatVtable.*ptr : s_origVtable.*ptr; + } + }; + + class DriverHook + { + public: + template + static decltype(s_compatVtable.*ptr) getCompatFunc(HANDLE context) + { + return s_compatVtable.*ptr ? s_compatVtable.*ptr : s_origVtables.at(context).*ptr; + } + }; + + template + class HookVisitor + { + public: + HookVisitor(const Vtable& srcVtable, Vtable& origVtable) + : m_srcVtable(srcVtable) + , m_origVtable(origVtable) + { + } template void visit() { - s_origVtable.*ptr = m_origVtable.*ptr; - - if (!(s_compatVtable.*ptr)) + m_origVtable.*ptr = m_srcVtable.*ptr; + if (s_compatVtable.*ptr) { - s_threadSafeVtable.*ptr = s_origVtable.*ptr; - s_compatVtable.*ptr = s_origVtable.*ptr; - } - else - { - s_threadSafeVtable.*ptr = getThreadSafeFuncPtr(s_compatVtable.*ptr); - Compat::hookFunction(reinterpret_cast(s_origVtable.*ptr), s_threadSafeVtable.*ptr); + Compat::hookFunction(reinterpret_cast(m_origVtable.*ptr), + getThreadSafeFuncPtr(m_origVtable.*ptr)); } } @@ -73,15 +102,9 @@ private: Compat::Log() << "Hooking function: " << vtableTypeName << "::" << funcName; s_funcNames[getKey()] = vtableTypeName + "::" + funcName; - s_origVtable.*ptr = m_origVtable.*ptr; - - s_threadSafeVtable.*ptr = getThreadSafeFuncPtr(s_compatVtable.*ptr); - Compat::hookFunction(reinterpret_cast(s_origVtable.*ptr), s_threadSafeVtable.*ptr); - - if (!(s_compatVtable.*ptr)) - { - s_compatVtable.*ptr = s_origVtable.*ptr; - } + m_origVtable.*ptr = m_srcVtable.*ptr; + Compat::hookFunction(reinterpret_cast(m_origVtable.*ptr), + getThreadSafeFuncPtr(m_origVtable.*ptr)); } private: @@ -108,38 +131,38 @@ private: return &threadSafeFunc; } - template - static Result STDMETHODCALLTYPE threadSafeFunc(Params... params) + template + static Result STDMETHODCALLTYPE threadSafeFunc(FirstParam firstParam, Params... params) { DDraw::ScopedThreadLock lock; #ifdef _DEBUG - Compat::LogEnter(s_funcNames[getKey()].c_str(), params...); -#endif - - Result result = (s_compatVtable.*ptr)(params...); - -#ifdef _DEBUG - Compat::LogLeave(s_funcNames[getKey()].c_str(), params...) << result; -#endif + const char* funcName = s_funcNames[getKey()].c_str(); + Compat::LogEnter(funcName, firstParam, params...); + Result result = Hook::getCompatFunc(firstParam)(firstParam, params...); + Compat::LogLeave(funcName, firstParam, params...) << result; return result; +#else + return (s_compatVtable.*ptr)(firstParam, params...); +#endif } - template - static void STDMETHODCALLTYPE threadSafeFunc(Params... params) + template + static void STDMETHODCALLTYPE threadSafeFunc(FirstParam firstParam, Params... params) { DDraw::ScopedThreadLock lock; #ifdef _DEBUG - Compat::LogEnter(s_funcNames[getKey()].c_str(), params...); -#endif - - (s_compatVtable.*ptr)(params...); - -#ifdef _DEBUG - Compat::LogLeave(s_funcNames[getKey()].c_str(), params...); + const char* funcName = s_funcNames[getKey()].c_str(); + Compat::LogEnter(funcName, firstParam, params...); + Hook::getCompatFunc(firstParam)(firstParam, params...); + Compat::LogLeave(funcName, firstParam, params...); +#else + (s_compatVtable.*ptr)(firstParam, params...); #endif } - const Vtable& m_origVtable; + const Vtable& m_srcVtable; + Vtable& m_origVtable; }; static Vtable createCompatVtable() @@ -156,21 +179,20 @@ private: } static Vtable s_compatVtable; - static Vtable s_threadSafeVtable; static std::map, std::string> s_funcNames; }; template Vtable CompatVtable::s_origVtable = {}; +template +std::map CompatVtable::s_origVtables; + template const Vtable* CompatVtable::s_origVtablePtr = nullptr; template Vtable CompatVtable::s_compatVtable(getCompatVtable()); -template -Vtable CompatVtable::s_threadSafeVtable = {}; - template std::map, std::string> CompatVtable::s_funcNames; diff --git a/DDrawCompat/Common/Hook.cpp b/DDrawCompat/Common/Hook.cpp index 45ac671..9945cba 100644 --- a/DDrawCompat/Common/Hook.cpp +++ b/DDrawCompat/Common/Hook.cpp @@ -1,5 +1,6 @@ #define WIN32_LEAN_AND_MEAN +#include #include #include @@ -13,12 +14,19 @@ namespace { struct HookedFunctionInfo { + HMODULE module; void* trampoline; void* newFunction; }; std::map g_hookedFunctions; + std::map::iterator findOrigFunc(void* origFunc) + { + return std::find_if(g_hookedFunctions.begin(), g_hookedFunctions.end(), + [=](const auto& i) { return origFunc == i.first || origFunc == i.second.trampoline; }); + } + FARPROC getProcAddress(HMODULE module, const char* procName) { if (!module || !procName) @@ -59,7 +67,7 @@ namespace void hookFunction(const char* funcName, void*& origFuncPtr, void* newFuncPtr) { - const auto it = g_hookedFunctions.find(origFuncPtr); + const auto it = findOrigFunc(origFuncPtr); if (it != g_hookedFunctions.end()) { origFuncPtr = it->second.trampoline; @@ -84,7 +92,23 @@ namespace return; } - g_hookedFunctions[hookedFuncPtr] = { origFuncPtr, newFuncPtr }; + HMODULE module = nullptr; + GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, + reinterpret_cast(hookedFuncPtr), &module); + g_hookedFunctions[hookedFuncPtr] = { module, origFuncPtr, newFuncPtr }; + } + + void unhookFunction(const std::map::iterator& hookedFunc) + { + DetourTransactionBegin(); + DetourDetach(&hookedFunc->second.trampoline, hookedFunc->second.newFunction); + DetourTransactionCommit(); + + if (hookedFunc->second.module) + { + FreeLibrary(hookedFunc->second.module); + } + g_hookedFunctions.erase(hookedFunc); } } @@ -95,9 +119,9 @@ namespace Compat ::hookFunction(nullptr, origFuncPtr, newFuncPtr); } - void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr) + void hookFunction(HMODULE module, const char* funcName, void*& origFuncPtr, void* newFuncPtr) { - FARPROC procAddr = getProcAddress(GetModuleHandle(moduleName), funcName); + FARPROC procAddr = getProcAddress(module, funcName); if (!procAddr) { Compat::LogDebug() << "Failed to load the address of a function: " << funcName; @@ -108,13 +132,31 @@ namespace Compat ::hookFunction(funcName, origFuncPtr, newFuncPtr); } + void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr) + { + HMODULE module = LoadLibrary(moduleName); + if (!module) + { + return; + } + hookFunction(module, funcName, origFuncPtr, newFuncPtr); + FreeLibrary(module); + } + void unhookAllFunctions() { - for (auto& hookedFunc : g_hookedFunctions) + while (!g_hookedFunctions.empty()) { - DetourTransactionBegin(); - DetourDetach(&hookedFunc.second.trampoline, hookedFunc.second.newFunction); - DetourTransactionCommit(); + ::unhookFunction(g_hookedFunctions.begin()); } } -} \ No newline at end of file + + void unhookFunction(void* origFunc) + { + auto it = findOrigFunc(origFunc); + if (it != g_hookedFunctions.end()) + { + ::unhookFunction(it); + } + } +} diff --git a/DDrawCompat/Common/Hook.h b/DDrawCompat/Common/Hook.h index b9440e2..4480666 100644 --- a/DDrawCompat/Common/Hook.h +++ b/DDrawCompat/Common/Hook.h @@ -1,5 +1,9 @@ #pragma once +#define WIN32_LEAN_AND_MEAN + +#include + #define CALL_ORIG_FUNC(func) Compat::getOrigFuncPtr() #define HOOK_FUNCTION(module, func, newFunc) \ @@ -15,6 +19,7 @@ namespace Compat } void hookFunction(void*& origFuncPtr, void* newFuncPtr); + void hookFunction(HMODULE module, const char* funcName, void*& origFuncPtr, void* newFuncPtr); void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr); template @@ -25,4 +30,5 @@ namespace Compat } void unhookAllFunctions(); + void unhookFunction(void* origFunc); } diff --git a/DDrawCompat/D3dDdi/AdapterFuncs.cpp b/DDrawCompat/D3dDdi/AdapterFuncs.cpp index 45ca3b3..9084404 100644 --- a/DDrawCompat/D3dDdi/AdapterFuncs.cpp +++ b/DDrawCompat/D3dDdi/AdapterFuncs.cpp @@ -4,14 +4,24 @@ namespace { + HRESULT APIENTRY closeAdapter(HANDLE hAdapter) + { + HRESULT result = D3dDdi::AdapterFuncs::s_origVtables.at(hAdapter).pfnCloseAdapter(hAdapter); + if (SUCCEEDED(result)) + { + D3dDdi::AdapterFuncs::s_origVtables.erase(hAdapter); + } + return result; + } + HRESULT APIENTRY createDevice(HANDLE hAdapter, D3DDDIARG_CREATEDEVICE* pCreateData) { D3dDdi::DeviceCallbacks::hookVtable(pCreateData->pCallbacks); - HRESULT result = D3dDdi::AdapterFuncs::s_origVtable.pfnCreateDevice( + HRESULT result = D3dDdi::AdapterFuncs::s_origVtables.at(hAdapter).pfnCreateDevice( hAdapter, pCreateData); if (SUCCEEDED(result)) { - D3dDdi::DeviceFuncs::hookVtable(pCreateData->pDeviceFuncs); + D3dDdi::DeviceFuncs::hookDriverVtable(pCreateData->hDevice, pCreateData->pDeviceFuncs); } return result; } @@ -21,6 +31,7 @@ namespace D3dDdi { void AdapterFuncs::setCompatVtable(D3DDDI_ADAPTERFUNCS& vtable) { + vtable.pfnCloseAdapter = &closeAdapter; vtable.pfnCreateDevice = &createDevice; } } diff --git a/DDrawCompat/D3dDdi/DeviceFuncs.cpp b/DDrawCompat/D3dDdi/DeviceFuncs.cpp index df60245..5448a6e 100644 --- a/DDrawCompat/D3dDdi/DeviceFuncs.cpp +++ b/DDrawCompat/D3dDdi/DeviceFuncs.cpp @@ -99,9 +99,23 @@ std::ostream& operator<<(std::ostream& os, const D3DDDIBOX& box) << box.Back; } -namespace D3dDdi +namespace { - void DeviceFuncs::setCompatVtable(D3DDDI_DEVICEFUNCS& /*vtable*/) + HRESULT APIENTRY destroyDevice(HANDLE hDevice) { + HRESULT result = D3dDdi::DeviceFuncs::s_origVtables.at(hDevice).pfnDestroyDevice(hDevice); + if (SUCCEEDED(result)) + { + D3dDdi::DeviceFuncs::s_origVtables.erase(hDevice); + } + return result; + } +} + +namespace D3dDdi +{ + void DeviceFuncs::setCompatVtable(D3DDDI_DEVICEFUNCS& vtable) + { + vtable.pfnDestroyDevice = &destroyDevice; } } diff --git a/DDrawCompat/D3dDdi/Hooks.cpp b/DDrawCompat/D3dDdi/Hooks.cpp index 4de403a..47d89b3 100644 --- a/DDrawCompat/D3dDdi/Hooks.cpp +++ b/DDrawCompat/D3dDdi/Hooks.cpp @@ -1,5 +1,7 @@ #define CINTERFACE +#include + #include #include #include @@ -10,8 +12,6 @@ #include "D3dDdi/AdapterCallbacks.h" #include "D3dDdi/AdapterFuncs.h" -HRESULT APIENTRY OpenAdapter(D3DDDIARG_OPENADAPTER*) { return 0; } - std::ostream& operator<<(std::ostream& os, const D3DDDIARG_OPENADAPTER& data) { return Compat::LogStruct(os) @@ -26,116 +26,67 @@ std::ostream& operator<<(std::ostream& os, const D3DDDIARG_OPENADAPTER& data) namespace { UINT g_ddiVersion = 0; - HMODULE g_umd = nullptr; - - D3DKMT_HANDLE openAdapterFromHdc(HDC hdc); + std::wstring g_hookedUmdFileName; + PFND3DDDI_OPENADAPTER g_origOpenAdapter = nullptr; - void closeAdapter(D3DKMT_HANDLE adapter) + void hookOpenAdapter(const std::wstring& umdFileName); + HRESULT APIENTRY openAdapter(D3DDDIARG_OPENADAPTER* pOpenData); + void unhookOpenAdapter(); + + NTSTATUS APIENTRY d3dKmtQueryAdapterInfo(const D3DKMT_QUERYADAPTERINFO* pData) { - D3DKMT_CLOSEADAPTER closeAdapterData = {}; - closeAdapterData.hAdapter = adapter; - D3DKMTCloseAdapter(&closeAdapterData); + NTSTATUS result = CALL_ORIG_FUNC(D3DKMTQueryAdapterInfo)(pData); + if (SUCCEEDED(result) && KMTQAITYPE_UMDRIVERNAME == pData->Type) + { + auto info = static_cast(pData->pPrivateDriverData); + if (g_hookedUmdFileName != info->UmdFileName) + { + unhookOpenAdapter(); + hookOpenAdapter(info->UmdFileName); + } + } + return result; } - DISPLAY_DEVICE getPrimaryDisplayDevice() + void hookOpenAdapter(const std::wstring& umdFileName) { - DISPLAY_DEVICE dd = {}; - dd.cb = sizeof(dd); - for (DWORD i = 0; - EnumDisplayDevices(nullptr, i, &dd, 0) && !(dd.StateFlags & DISPLAY_DEVICE_PRIMARY_DEVICE); - ++i) + g_hookedUmdFileName = umdFileName; + HMODULE module = LoadLibraryW(umdFileName.c_str()); + if (module) { + Compat::hookFunction(module, "OpenAdapter", + reinterpret_cast(g_origOpenAdapter), &openAdapter); + FreeLibrary(module); } - - if (!(dd.StateFlags & DISPLAY_DEVICE_PRIMARY_DEVICE)) - { - Compat::Log() << "Failed to find the primary display device"; - ZeroMemory(&dd, sizeof(dd)); - } - - return dd; - } - - D3DKMT_UMDFILENAMEINFO getUmdDriverName(D3DKMT_HANDLE adapter) - { - D3DKMT_UMDFILENAMEINFO umdFileNameInfo = {}; - umdFileNameInfo.Version = KMTUMDVERSION_DX9; - - D3DKMT_QUERYADAPTERINFO queryAdapterInfo = {}; - queryAdapterInfo.hAdapter = adapter; - queryAdapterInfo.Type = KMTQAITYPE_UMDRIVERNAME; - queryAdapterInfo.pPrivateDriverData = &umdFileNameInfo; - queryAdapterInfo.PrivateDriverDataSize = sizeof(umdFileNameInfo); - NTSTATUS result = D3DKMTQueryAdapterInfo(&queryAdapterInfo); - if (FAILED(result)) - { - Compat::Log() << "Failed to query the display driver name: " << result; - ZeroMemory(&umdFileNameInfo, sizeof(umdFileNameInfo)); - } - - return umdFileNameInfo; - } - - D3DKMT_UMDFILENAMEINFO getPrimaryUmdDriverName() - { - D3DKMT_UMDFILENAMEINFO umdFileNameInfo = {}; - - DISPLAY_DEVICE dd = getPrimaryDisplayDevice(); - if (!dd.DeviceName) - { - return umdFileNameInfo; - } - - HDC dc = CreateDC(nullptr, dd.DeviceName, nullptr, nullptr); - if (!dc) - { - Compat::Log() << "Failed to create a DC for the primary display device"; - return umdFileNameInfo; - } - - D3DKMT_HANDLE adapter = openAdapterFromHdc(dc); - DeleteDC(dc); - if (!adapter) - { - return umdFileNameInfo; - } - - umdFileNameInfo = getUmdDriverName(adapter); - closeAdapter(adapter); - if (0 == umdFileNameInfo.UmdFileName[0]) - { - return umdFileNameInfo; - } - - Compat::Log() << "Primary display adapter driver: " << umdFileNameInfo.UmdFileName; - return umdFileNameInfo; } HRESULT APIENTRY openAdapter(D3DDDIARG_OPENADAPTER* pOpenData) { Compat::LogEnter("openAdapter", pOpenData); D3dDdi::AdapterCallbacks::hookVtable(pOpenData->pAdapterCallbacks); - HRESULT result = CALL_ORIG_FUNC(OpenAdapter)(pOpenData); + HRESULT result = g_origOpenAdapter(pOpenData); if (SUCCEEDED(result)) { + static std::set hookedUmdFileNames; + if (hookedUmdFileNames.find(g_hookedUmdFileName) == hookedUmdFileNames.end()) + { + Compat::Log() << "Hooking user mode display driver: " << g_hookedUmdFileName.c_str(); + hookedUmdFileNames.insert(g_hookedUmdFileName); + } g_ddiVersion = min(pOpenData->Version, pOpenData->DriverVersion); - D3dDdi::AdapterFuncs::hookVtable(pOpenData->pAdapterFuncs); + D3dDdi::AdapterFuncs::hookDriverVtable(pOpenData->hAdapter, pOpenData->pAdapterFuncs); } Compat::LogLeave("openAdapter", pOpenData) << result; return result; } - D3DKMT_HANDLE openAdapterFromHdc(HDC hdc) + void unhookOpenAdapter() { - D3DKMT_OPENADAPTERFROMHDC openAdapterData = {}; - openAdapterData.hDc = hdc; - NTSTATUS result = D3DKMTOpenAdapterFromHdc(&openAdapterData); - if (FAILED(result)) + if (g_origOpenAdapter) { - Compat::Log() << "Failed to open the primary display adapter: " << result; - return 0; + Compat::unhookFunction(g_origOpenAdapter); + g_hookedUmdFileName.clear(); } - return openAdapterData.hAdapter; } } @@ -148,24 +99,11 @@ namespace D3dDdi void installHooks() { - D3DKMT_UMDFILENAMEINFO primaryUmd = getPrimaryUmdDriverName(); - g_umd = LoadLibraryW(primaryUmd.UmdFileName); - if (!g_umd) - { - Compat::Log() << "Failed to load the primary display driver library"; - } - - char umdFileName[MAX_PATH] = {}; - wcstombs_s(nullptr, umdFileName, primaryUmd.UmdFileName, _TRUNCATE); - Compat::hookFunction( - umdFileName, "OpenAdapter", &openAdapter); + HOOK_FUNCTION(gdi32, D3DKMTQueryAdapterInfo, d3dKmtQueryAdapterInfo); } void uninstallHooks() { - if (g_umd) - { - FreeLibrary(g_umd); - } + unhookOpenAdapter(); } }