1
0
mirror of https://github.com/narzoul/DDrawCompat synced 2024-12-30 08:55:36 +01:00

Allow hooking multiple user mode display drivers

This commit is contained in:
narzoul 2016-09-19 00:12:45 +02:00
parent aabe85db65
commit 3746362528
6 changed files with 195 additions and 162 deletions

View File

@ -36,34 +36,63 @@ public:
{
s_origVtablePtr = vtable;
InitVisitor visitor(*vtable);
HookVisitor<DDrawHook> visitor(*vtable, s_origVtable);
forEach<Vtable>(visitor);
}
}
static void hookDriverVtable(HANDLE context, const Vtable* vtable)
{
if (vtable && s_origVtables.find(context) == s_origVtables.end())
{
HookVisitor<DriverHook> visitor(*vtable, s_origVtables[context]);
forEach<Vtable>(visitor);
}
}
static Vtable s_origVtable;
static std::map<HANDLE, Vtable> s_origVtables;
static const Vtable* s_origVtablePtr;
private:
class InitVisitor
class DDrawHook
{
public:
InitVisitor(const Vtable& origVtable) : m_origVtable(origVtable) {}
template <typename MemberDataPtr, MemberDataPtr ptr, typename FirstParam>
static decltype(s_compatVtable.*ptr) getCompatFunc(FirstParam)
{
return s_compatVtable.*ptr ? s_compatVtable.*ptr : s_origVtable.*ptr;
}
};
class DriverHook
{
public:
template <typename MemberDataPtr, MemberDataPtr ptr>
static decltype(s_compatVtable.*ptr) getCompatFunc(HANDLE context)
{
return s_compatVtable.*ptr ? s_compatVtable.*ptr : s_origVtables.at(context).*ptr;
}
};
template <typename Hook>
class HookVisitor
{
public:
HookVisitor(const Vtable& srcVtable, Vtable& origVtable)
: m_srcVtable(srcVtable)
, m_origVtable(origVtable)
{
}
template <typename MemberDataPtr, MemberDataPtr ptr>
void visit()
{
s_origVtable.*ptr = m_origVtable.*ptr;
if (!(s_compatVtable.*ptr))
m_origVtable.*ptr = m_srcVtable.*ptr;
if (s_compatVtable.*ptr)
{
s_threadSafeVtable.*ptr = s_origVtable.*ptr;
s_compatVtable.*ptr = s_origVtable.*ptr;
}
else
{
s_threadSafeVtable.*ptr = getThreadSafeFuncPtr<MemberDataPtr, ptr>(s_compatVtable.*ptr);
Compat::hookFunction(reinterpret_cast<void*&>(s_origVtable.*ptr), s_threadSafeVtable.*ptr);
Compat::hookFunction(reinterpret_cast<void*&>(m_origVtable.*ptr),
getThreadSafeFuncPtr<MemberDataPtr, ptr>(m_origVtable.*ptr));
}
}
@ -73,15 +102,9 @@ private:
Compat::Log() << "Hooking function: " << vtableTypeName << "::" << funcName;
s_funcNames[getKey<MemberDataPtr, ptr>()] = vtableTypeName + "::" + funcName;
s_origVtable.*ptr = m_origVtable.*ptr;
s_threadSafeVtable.*ptr = getThreadSafeFuncPtr<MemberDataPtr, ptr>(s_compatVtable.*ptr);
Compat::hookFunction(reinterpret_cast<void*&>(s_origVtable.*ptr), s_threadSafeVtable.*ptr);
if (!(s_compatVtable.*ptr))
{
s_compatVtable.*ptr = s_origVtable.*ptr;
}
m_origVtable.*ptr = m_srcVtable.*ptr;
Compat::hookFunction(reinterpret_cast<void*&>(m_origVtable.*ptr),
getThreadSafeFuncPtr<MemberDataPtr, ptr>(m_origVtable.*ptr));
}
private:
@ -108,38 +131,38 @@ private:
return &threadSafeFunc<MemberDataPtr, ptr, Params...>;
}
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static Result STDMETHODCALLTYPE threadSafeFunc(Params... params)
template <typename MemberDataPtr, MemberDataPtr ptr,
typename Result, typename FirstParam, typename... Params>
static Result STDMETHODCALLTYPE threadSafeFunc(FirstParam firstParam, Params... params)
{
DDraw::ScopedThreadLock lock;
#ifdef _DEBUG
Compat::LogEnter(s_funcNames[getKey<MemberDataPtr, ptr>()].c_str(), params...);
#endif
Result result = (s_compatVtable.*ptr)(params...);
#ifdef _DEBUG
Compat::LogLeave(s_funcNames[getKey<MemberDataPtr, ptr>()].c_str(), params...) << result;
#endif
const char* funcName = s_funcNames[getKey<MemberDataPtr, ptr>()].c_str();
Compat::LogEnter(funcName, firstParam, params...);
Result result = Hook::getCompatFunc<MemberDataPtr, ptr>(firstParam)(firstParam, params...);
Compat::LogLeave(funcName, firstParam, params...) << result;
return result;
#else
return (s_compatVtable.*ptr)(firstParam, params...);
#endif
}
template <typename MemberDataPtr, MemberDataPtr ptr, typename... Params>
static void STDMETHODCALLTYPE threadSafeFunc(Params... params)
template <typename MemberDataPtr, MemberDataPtr ptr, typename FirstParam, typename... Params>
static void STDMETHODCALLTYPE threadSafeFunc(FirstParam firstParam, Params... params)
{
DDraw::ScopedThreadLock lock;
#ifdef _DEBUG
Compat::LogEnter(s_funcNames[getKey<MemberDataPtr, ptr>()].c_str(), params...);
#endif
(s_compatVtable.*ptr)(params...);
#ifdef _DEBUG
Compat::LogLeave(s_funcNames[getKey<MemberDataPtr, ptr>()].c_str(), params...);
const char* funcName = s_funcNames[getKey<MemberDataPtr, ptr>()].c_str();
Compat::LogEnter(funcName, firstParam, params...);
Hook::getCompatFunc<MemberDataPtr, ptr>(firstParam)(firstParam, params...);
Compat::LogLeave(funcName, firstParam, params...);
#else
(s_compatVtable.*ptr)(firstParam, params...);
#endif
}
const Vtable& m_origVtable;
const Vtable& m_srcVtable;
Vtable& m_origVtable;
};
static Vtable createCompatVtable()
@ -156,21 +179,20 @@ private:
}
static Vtable s_compatVtable;
static Vtable s_threadSafeVtable;
static std::map<std::vector<unsigned char>, std::string> s_funcNames;
};
template <typename Vtable>
Vtable CompatVtable<Vtable>::s_origVtable = {};
template <typename Vtable>
std::map<HANDLE, Vtable> CompatVtable<Vtable>::s_origVtables;
template <typename Vtable>
const Vtable* CompatVtable<Vtable>::s_origVtablePtr = nullptr;
template <typename Vtable>
Vtable CompatVtable<Vtable>::s_compatVtable(getCompatVtable());
template <typename Vtable>
Vtable CompatVtable<Vtable>::s_threadSafeVtable = {};
template <typename Vtable>
std::map<std::vector<unsigned char>, std::string> CompatVtable<Vtable>::s_funcNames;

View File

@ -1,5 +1,6 @@
#define WIN32_LEAN_AND_MEAN
#include <algorithm>
#include <map>
#include <utility>
@ -13,12 +14,19 @@ namespace
{
struct HookedFunctionInfo
{
HMODULE module;
void* trampoline;
void* newFunction;
};
std::map<void*, HookedFunctionInfo> g_hookedFunctions;
std::map<void*, HookedFunctionInfo>::iterator findOrigFunc(void* origFunc)
{
return std::find_if(g_hookedFunctions.begin(), g_hookedFunctions.end(),
[=](const auto& i) { return origFunc == i.first || origFunc == i.second.trampoline; });
}
FARPROC getProcAddress(HMODULE module, const char* procName)
{
if (!module || !procName)
@ -59,7 +67,7 @@ namespace
void hookFunction(const char* funcName, void*& origFuncPtr, void* newFuncPtr)
{
const auto it = g_hookedFunctions.find(origFuncPtr);
const auto it = findOrigFunc(origFuncPtr);
if (it != g_hookedFunctions.end())
{
origFuncPtr = it->second.trampoline;
@ -84,7 +92,23 @@ namespace
return;
}
g_hookedFunctions[hookedFuncPtr] = { origFuncPtr, newFuncPtr };
HMODULE module = nullptr;
GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
reinterpret_cast<char*>(hookedFuncPtr), &module);
g_hookedFunctions[hookedFuncPtr] = { module, origFuncPtr, newFuncPtr };
}
void unhookFunction(const std::map<void*, HookedFunctionInfo>::iterator& hookedFunc)
{
DetourTransactionBegin();
DetourDetach(&hookedFunc->second.trampoline, hookedFunc->second.newFunction);
DetourTransactionCommit();
if (hookedFunc->second.module)
{
FreeLibrary(hookedFunc->second.module);
}
g_hookedFunctions.erase(hookedFunc);
}
}
@ -95,9 +119,9 @@ namespace Compat
::hookFunction(nullptr, origFuncPtr, newFuncPtr);
}
void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr)
void hookFunction(HMODULE module, const char* funcName, void*& origFuncPtr, void* newFuncPtr)
{
FARPROC procAddr = getProcAddress(GetModuleHandle(moduleName), funcName);
FARPROC procAddr = getProcAddress(module, funcName);
if (!procAddr)
{
Compat::LogDebug() << "Failed to load the address of a function: " << funcName;
@ -108,13 +132,31 @@ namespace Compat
::hookFunction(funcName, origFuncPtr, newFuncPtr);
}
void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr)
{
HMODULE module = LoadLibrary(moduleName);
if (!module)
{
return;
}
hookFunction(module, funcName, origFuncPtr, newFuncPtr);
FreeLibrary(module);
}
void unhookAllFunctions()
{
for (auto& hookedFunc : g_hookedFunctions)
while (!g_hookedFunctions.empty())
{
DetourTransactionBegin();
DetourDetach(&hookedFunc.second.trampoline, hookedFunc.second.newFunction);
DetourTransactionCommit();
::unhookFunction(g_hookedFunctions.begin());
}
}
}
void unhookFunction(void* origFunc)
{
auto it = findOrigFunc(origFunc);
if (it != g_hookedFunctions.end())
{
::unhookFunction(it);
}
}
}

View File

@ -1,5 +1,9 @@
#pragma once
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#define CALL_ORIG_FUNC(func) Compat::getOrigFuncPtr<decltype(&func), &func>()
#define HOOK_FUNCTION(module, func, newFunc) \
@ -15,6 +19,7 @@ namespace Compat
}
void hookFunction(void*& origFuncPtr, void* newFuncPtr);
void hookFunction(HMODULE module, const char* funcName, void*& origFuncPtr, void* newFuncPtr);
void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr);
template <typename OrigFuncPtr, OrigFuncPtr origFunc>
@ -25,4 +30,5 @@ namespace Compat
}
void unhookAllFunctions();
void unhookFunction(void* origFunc);
}

View File

@ -4,14 +4,24 @@
namespace
{
HRESULT APIENTRY closeAdapter(HANDLE hAdapter)
{
HRESULT result = D3dDdi::AdapterFuncs::s_origVtables.at(hAdapter).pfnCloseAdapter(hAdapter);
if (SUCCEEDED(result))
{
D3dDdi::AdapterFuncs::s_origVtables.erase(hAdapter);
}
return result;
}
HRESULT APIENTRY createDevice(HANDLE hAdapter, D3DDDIARG_CREATEDEVICE* pCreateData)
{
D3dDdi::DeviceCallbacks::hookVtable(pCreateData->pCallbacks);
HRESULT result = D3dDdi::AdapterFuncs::s_origVtable.pfnCreateDevice(
HRESULT result = D3dDdi::AdapterFuncs::s_origVtables.at(hAdapter).pfnCreateDevice(
hAdapter, pCreateData);
if (SUCCEEDED(result))
{
D3dDdi::DeviceFuncs::hookVtable(pCreateData->pDeviceFuncs);
D3dDdi::DeviceFuncs::hookDriverVtable(pCreateData->hDevice, pCreateData->pDeviceFuncs);
}
return result;
}
@ -21,6 +31,7 @@ namespace D3dDdi
{
void AdapterFuncs::setCompatVtable(D3DDDI_ADAPTERFUNCS& vtable)
{
vtable.pfnCloseAdapter = &closeAdapter;
vtable.pfnCreateDevice = &createDevice;
}
}

View File

@ -99,9 +99,23 @@ std::ostream& operator<<(std::ostream& os, const D3DDDIBOX& box)
<< box.Back;
}
namespace D3dDdi
namespace
{
void DeviceFuncs::setCompatVtable(D3DDDI_DEVICEFUNCS& /*vtable*/)
HRESULT APIENTRY destroyDevice(HANDLE hDevice)
{
HRESULT result = D3dDdi::DeviceFuncs::s_origVtables.at(hDevice).pfnDestroyDevice(hDevice);
if (SUCCEEDED(result))
{
D3dDdi::DeviceFuncs::s_origVtables.erase(hDevice);
}
return result;
}
}
namespace D3dDdi
{
void DeviceFuncs::setCompatVtable(D3DDDI_DEVICEFUNCS& vtable)
{
vtable.pfnDestroyDevice = &destroyDevice;
}
}

View File

@ -1,5 +1,7 @@
#define CINTERFACE
#include <set>
#include <Windows.h>
#include <d3d.h>
#include <d3dumddi.h>
@ -10,8 +12,6 @@
#include "D3dDdi/AdapterCallbacks.h"
#include "D3dDdi/AdapterFuncs.h"
HRESULT APIENTRY OpenAdapter(D3DDDIARG_OPENADAPTER*) { return 0; }
std::ostream& operator<<(std::ostream& os, const D3DDDIARG_OPENADAPTER& data)
{
return Compat::LogStruct(os)
@ -26,116 +26,67 @@ std::ostream& operator<<(std::ostream& os, const D3DDDIARG_OPENADAPTER& data)
namespace
{
UINT g_ddiVersion = 0;
HMODULE g_umd = nullptr;
D3DKMT_HANDLE openAdapterFromHdc(HDC hdc);
std::wstring g_hookedUmdFileName;
PFND3DDDI_OPENADAPTER g_origOpenAdapter = nullptr;
void closeAdapter(D3DKMT_HANDLE adapter)
void hookOpenAdapter(const std::wstring& umdFileName);
HRESULT APIENTRY openAdapter(D3DDDIARG_OPENADAPTER* pOpenData);
void unhookOpenAdapter();
NTSTATUS APIENTRY d3dKmtQueryAdapterInfo(const D3DKMT_QUERYADAPTERINFO* pData)
{
D3DKMT_CLOSEADAPTER closeAdapterData = {};
closeAdapterData.hAdapter = adapter;
D3DKMTCloseAdapter(&closeAdapterData);
NTSTATUS result = CALL_ORIG_FUNC(D3DKMTQueryAdapterInfo)(pData);
if (SUCCEEDED(result) && KMTQAITYPE_UMDRIVERNAME == pData->Type)
{
auto info = static_cast<D3DKMT_UMDFILENAMEINFO*>(pData->pPrivateDriverData);
if (g_hookedUmdFileName != info->UmdFileName)
{
unhookOpenAdapter();
hookOpenAdapter(info->UmdFileName);
}
}
return result;
}
DISPLAY_DEVICE getPrimaryDisplayDevice()
void hookOpenAdapter(const std::wstring& umdFileName)
{
DISPLAY_DEVICE dd = {};
dd.cb = sizeof(dd);
for (DWORD i = 0;
EnumDisplayDevices(nullptr, i, &dd, 0) && !(dd.StateFlags & DISPLAY_DEVICE_PRIMARY_DEVICE);
++i)
g_hookedUmdFileName = umdFileName;
HMODULE module = LoadLibraryW(umdFileName.c_str());
if (module)
{
Compat::hookFunction(module, "OpenAdapter",
reinterpret_cast<void*&>(g_origOpenAdapter), &openAdapter);
FreeLibrary(module);
}
if (!(dd.StateFlags & DISPLAY_DEVICE_PRIMARY_DEVICE))
{
Compat::Log() << "Failed to find the primary display device";
ZeroMemory(&dd, sizeof(dd));
}
return dd;
}
D3DKMT_UMDFILENAMEINFO getUmdDriverName(D3DKMT_HANDLE adapter)
{
D3DKMT_UMDFILENAMEINFO umdFileNameInfo = {};
umdFileNameInfo.Version = KMTUMDVERSION_DX9;
D3DKMT_QUERYADAPTERINFO queryAdapterInfo = {};
queryAdapterInfo.hAdapter = adapter;
queryAdapterInfo.Type = KMTQAITYPE_UMDRIVERNAME;
queryAdapterInfo.pPrivateDriverData = &umdFileNameInfo;
queryAdapterInfo.PrivateDriverDataSize = sizeof(umdFileNameInfo);
NTSTATUS result = D3DKMTQueryAdapterInfo(&queryAdapterInfo);
if (FAILED(result))
{
Compat::Log() << "Failed to query the display driver name: " << result;
ZeroMemory(&umdFileNameInfo, sizeof(umdFileNameInfo));
}
return umdFileNameInfo;
}
D3DKMT_UMDFILENAMEINFO getPrimaryUmdDriverName()
{
D3DKMT_UMDFILENAMEINFO umdFileNameInfo = {};
DISPLAY_DEVICE dd = getPrimaryDisplayDevice();
if (!dd.DeviceName)
{
return umdFileNameInfo;
}
HDC dc = CreateDC(nullptr, dd.DeviceName, nullptr, nullptr);
if (!dc)
{
Compat::Log() << "Failed to create a DC for the primary display device";
return umdFileNameInfo;
}
D3DKMT_HANDLE adapter = openAdapterFromHdc(dc);
DeleteDC(dc);
if (!adapter)
{
return umdFileNameInfo;
}
umdFileNameInfo = getUmdDriverName(adapter);
closeAdapter(adapter);
if (0 == umdFileNameInfo.UmdFileName[0])
{
return umdFileNameInfo;
}
Compat::Log() << "Primary display adapter driver: " << umdFileNameInfo.UmdFileName;
return umdFileNameInfo;
}
HRESULT APIENTRY openAdapter(D3DDDIARG_OPENADAPTER* pOpenData)
{
Compat::LogEnter("openAdapter", pOpenData);
D3dDdi::AdapterCallbacks::hookVtable(pOpenData->pAdapterCallbacks);
HRESULT result = CALL_ORIG_FUNC(OpenAdapter)(pOpenData);
HRESULT result = g_origOpenAdapter(pOpenData);
if (SUCCEEDED(result))
{
static std::set<std::wstring> hookedUmdFileNames;
if (hookedUmdFileNames.find(g_hookedUmdFileName) == hookedUmdFileNames.end())
{
Compat::Log() << "Hooking user mode display driver: " << g_hookedUmdFileName.c_str();
hookedUmdFileNames.insert(g_hookedUmdFileName);
}
g_ddiVersion = min(pOpenData->Version, pOpenData->DriverVersion);
D3dDdi::AdapterFuncs::hookVtable(pOpenData->pAdapterFuncs);
D3dDdi::AdapterFuncs::hookDriverVtable(pOpenData->hAdapter, pOpenData->pAdapterFuncs);
}
Compat::LogLeave("openAdapter", pOpenData) << result;
return result;
}
D3DKMT_HANDLE openAdapterFromHdc(HDC hdc)
void unhookOpenAdapter()
{
D3DKMT_OPENADAPTERFROMHDC openAdapterData = {};
openAdapterData.hDc = hdc;
NTSTATUS result = D3DKMTOpenAdapterFromHdc(&openAdapterData);
if (FAILED(result))
if (g_origOpenAdapter)
{
Compat::Log() << "Failed to open the primary display adapter: " << result;
return 0;
Compat::unhookFunction(g_origOpenAdapter);
g_hookedUmdFileName.clear();
}
return openAdapterData.hAdapter;
}
}
@ -148,24 +99,11 @@ namespace D3dDdi
void installHooks()
{
D3DKMT_UMDFILENAMEINFO primaryUmd = getPrimaryUmdDriverName();
g_umd = LoadLibraryW(primaryUmd.UmdFileName);
if (!g_umd)
{
Compat::Log() << "Failed to load the primary display driver library";
}
char umdFileName[MAX_PATH] = {};
wcstombs_s(nullptr, umdFileName, primaryUmd.UmdFileName, _TRUNCATE);
Compat::hookFunction<decltype(&OpenAdapter), &OpenAdapter>(
umdFileName, "OpenAdapter", &openAdapter);
HOOK_FUNCTION(gdi32, D3DKMTQueryAdapterInfo, d3dKmtQueryAdapterInfo);
}
void uninstallHooks()
{
if (g_umd)
{
FreeLibrary(g_umd);
}
unhookOpenAdapter();
}
}