1
0
mirror of https://github.com/narzoul/DDrawCompat synced 2024-12-30 08:55:36 +01:00

Replace vtable function pointers instead of using Detours

This commit is contained in:
narzoul 2021-02-24 20:35:15 +01:00
parent d268c51d8f
commit 75f8f7e689
15 changed files with 119 additions and 280 deletions

View File

@ -1,12 +1,13 @@
#pragma once
#include <Common/CompatVtableInstance.h>
#include <DDraw/ScopedThreadLock.h>
template <typename Interface>
using Vtable = typename std::remove_pointer<decltype(Interface::lpVtbl)>::type;
template <typename Vtable>
class CompatVtable : public CompatVtableInstance<Vtable>
class CompatVtable : public CompatVtableInstance<Vtable, DDraw::ScopedThreadLock>
{
public:
static const Vtable& getOrigVtable(const Vtable& vtable)
@ -21,7 +22,7 @@ public:
s_origVtablePtr = vtable;
Vtable compatVtable = {};
Compat::setCompatVtable(compatVtable);
CompatVtableInstance<Vtable>::hookVtable(*vtable, compatVtable);
CompatVtableInstance::hookVtable(*vtable, compatVtable);
}
}

View File

@ -1,8 +1,6 @@
#pragma once
#include <Common/LogWrapperVisitor.h>
#include <Common/VtableHookVisitor.h>
#include <Common/VtableUpdateVisitor.h>
#include <Common/VtableVisitor.h>
#define SET_COMPAT_VTABLE(Vtable, CompatInterface) \
@ -21,26 +19,14 @@ public:
static Vtable* s_origVtablePtr;
};
template <typename Vtable, int instanceId = -1>
template <typename Vtable, typename Lock, int instanceId = -1>
class CompatVtableInstance : public CompatVtableInstanceBase<Vtable>
{
public:
static void hookVtable(const Vtable& origVtable, Vtable compatVtable)
{
#ifdef DEBUGLOGS
LogWrapperVisitor<Vtable, instanceId> logWrapperVisitor(origVtable, compatVtable);
forEach<Vtable>(logWrapperVisitor);
#endif
VtableHookVisitor<Vtable, instanceId> vtableHookVisitor(origVtable, s_origVtable, compatVtable);
VtableHookVisitor<Vtable, Lock, instanceId> vtableHookVisitor(origVtable, s_origVtable, compatVtable);
forEach<Vtable>(vtableHookVisitor);
#ifdef DEBUGLOGS
VtableUpdateVisitor<Vtable> vtableUpdateVisitor(
origVtable, s_origVtable, LogWrapperVisitor<Vtable, instanceId>::s_compatVtable);
forEach<Vtable>(vtableUpdateVisitor);
#endif
s_origVtablePtr = &s_origVtable;
}
@ -50,5 +36,5 @@ public:
template <typename Vtable>
Vtable* CompatVtableInstanceBase<Vtable>::s_origVtablePtr = nullptr;
template <typename Vtable, int instanceId>
Vtable CompatVtableInstance<Vtable, instanceId>::s_origVtable = {};
template <typename Vtable, typename Lock, int instanceId>
Vtable CompatVtableInstance<Vtable, Lock, instanceId>::s_origVtable = {};

View File

@ -1,51 +0,0 @@
#pragma once
#include <map>
#include <string>
#include <vector>
#include <typeinfo>
template <typename Vtable>
class FuncNameVisitor
{
public:
template <typename MemberDataPtr, MemberDataPtr ptr>
static void visit(const char* funcName)
{
s_funcNames[getKey<MemberDataPtr, ptr>()] = s_vtableTypeName + "::" + funcName;
}
template <typename MemberDataPtr, MemberDataPtr ptr>
static const char* getFuncName()
{
return s_funcNames[getKey<MemberDataPtr, ptr>()].c_str();
}
private:
template <typename MemberDataPtr, MemberDataPtr ptr>
static std::vector<unsigned char> getKey()
{
MemberDataPtr mp = ptr;
unsigned char* p = reinterpret_cast<unsigned char*>(&mp);
return std::vector<unsigned char>(p, p + sizeof(mp));
}
static std::string getVtableTypeName()
{
std::string name = typeid(Vtable).name();
if (0 == name.find("struct "))
{
name = name.substr(name.find(" ") + 1);
}
return name;
}
static std::string s_vtableTypeName;
static std::map<std::vector<unsigned char>, std::string> s_funcNames;
};
template <typename Vtable>
std::string FuncNameVisitor<Vtable>::s_vtableTypeName(getVtableTypeName());
template <typename Vtable>
std::map<std::vector<unsigned char>, std::string> FuncNameVisitor<Vtable>::s_funcNames;

View File

@ -80,15 +80,6 @@ namespace
return nullptr;
}
std::string funcAddrToStr(void* funcPtr)
{
std::ostringstream oss;
HMODULE module = Compat::getModuleHandleFromAddress(funcPtr);
oss << getModulePath(module).string() << "+0x" << std::hex <<
reinterpret_cast<DWORD>(funcPtr) - reinterpret_cast<DWORD>(module);
return oss.str();
}
PIMAGE_NT_HEADERS getImageNtHeaders(HMODULE module)
{
PIMAGE_DOS_HEADER dosHeader = reinterpret_cast<PIMAGE_DOS_HEADER>(module);
@ -149,12 +140,12 @@ namespace
void* const hookedFuncPtr = origFuncPtr;
if (stubFuncPtr)
{
LOG_DEBUG << "Hooking function: " << funcName << " (" << funcAddrToStr(stubFuncPtr) << " -> "
<< funcAddrToStr(origFuncPtr) << ')';
LOG_DEBUG << "Hooking function: " << funcName << " (" << Compat::funcPtrToStr(stubFuncPtr) << " -> "
<< Compat::funcPtrToStr(origFuncPtr) << ')';
}
else
{
LOG_DEBUG << "Hooking function: " << funcName << " (" << funcAddrToStr(hookedFuncPtr) << ')';
LOG_DEBUG << "Hooking function: " << funcName << " (" << Compat::funcPtrToStr(hookedFuncPtr) << ')';
}
DetourTransactionBegin();
@ -189,6 +180,15 @@ namespace
namespace Compat
{
std::string funcPtrToStr(void* funcPtr)
{
std::ostringstream oss;
HMODULE module = Compat::getModuleHandleFromAddress(funcPtr);
oss << getModulePath(module).string() << "+0x" << std::hex <<
reinterpret_cast<DWORD>(funcPtr) - reinterpret_cast<DWORD>(module);
return oss.str();
}
HMODULE getModuleHandleFromAddress(void* address)
{
HMODULE module = nullptr;
@ -304,7 +304,7 @@ namespace Compat
FARPROC* func = findProcAddressInIat(module, importedModuleName, funcName);
if (func)
{
LOG_DEBUG << "Hooking function via IAT: " << funcName << " (" << funcAddrToStr(*func) << ')';
LOG_DEBUG << "Hooking function via IAT: " << funcName << " (" << funcPtrToStr(*func) << ')';
DWORD oldProtect = 0;
VirtualProtect(func, sizeof(func), PAGE_READWRITE, &oldProtect);
*func = static_cast<FARPROC>(newFuncPtr);

View File

@ -1,5 +1,7 @@
#pragma once
#include <string>
#include <Windows.h>
#define CALL_ORIG_FUNC(func) Compat::getOrigFuncPtr<decltype(&func), &func>()
@ -13,6 +15,7 @@
namespace Compat
{
std::string funcPtrToStr(void* funcPtr);
HMODULE getModuleHandleFromAddress(void* address);
template <typename OrigFuncPtr, OrigFuncPtr origFunc>

View File

@ -18,7 +18,7 @@
#define LOG_FUNC(...) Compat::LogFunc logFunc(__VA_ARGS__)
#define LOG_RESULT(...) logFunc.setResult(__VA_ARGS__)
#else
#define LOG_DEBUG if (false) Compat::Log()
#define LOG_DEBUG if constexpr (false) Compat::Log()
#define LOG_FUNC(...)
#define LOG_RESULT(...) __VA_ARGS__
#endif

View File

@ -1,69 +0,0 @@
#pragma once
#include <Common/FuncNameVisitor.h>
#include <Common/Log.h>
template <typename Vtable, int instanceId = 0>
class LogWrapperVisitor
{
public:
LogWrapperVisitor(const Vtable& origVtable, Vtable& compatVtable)
: m_origVtable(origVtable)
, m_compatVtable(compatVtable)
{
}
template <typename MemberDataPtr, MemberDataPtr ptr>
void visit(const char* funcName)
{
FuncNameVisitor<Vtable>::visit<MemberDataPtr, ptr>(funcName);
if (!(m_compatVtable.*ptr))
{
m_compatVtable.*ptr = m_origVtable.*ptr;
}
s_compatVtable.*ptr = m_compatVtable.*ptr;
m_compatVtable.*ptr = getLoggedFuncPtr<MemberDataPtr, ptr>(m_compatVtable.*ptr);
}
static Vtable s_compatVtable;
private:
template <typename Result, typename... Params>
using FuncPtr = Result(STDMETHODCALLTYPE *)(Params...);
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static FuncPtr<Result, Params...> getLoggedFuncPtr(FuncPtr<Result, Params...>)
{
return &loggedFunc<MemberDataPtr, ptr, Result, Params...>;
}
template <typename MemberDataPtr, MemberDataPtr ptr, typename... Params>
static FuncPtr<void, Params...> getLoggedFuncPtr(FuncPtr<void, Params...>)
{
return &loggedFunc<MemberDataPtr, ptr, Params...>;
}
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static Result STDMETHODCALLTYPE loggedFunc(Params... params)
{
const char* funcName = FuncNameVisitor<Vtable>::getFuncName<MemberDataPtr, ptr>();
LOG_FUNC(funcName, params...);
return LOG_RESULT((s_compatVtable.*ptr)(params...));
}
template <typename MemberDataPtr, MemberDataPtr ptr, typename... Params>
static void STDMETHODCALLTYPE loggedFunc(Params... params)
{
const char* funcName = FuncNameVisitor<Vtable>::getFuncName<MemberDataPtr, ptr>();
LOG_FUNC(funcName, params...);
(s_compatVtable.*ptr)(params...);
}
const Vtable& m_origVtable;
Vtable& m_compatVtable;
};
template <typename Vtable, int instanceId>
Vtable LogWrapperVisitor<Vtable, instanceId>::s_compatVtable;

View File

@ -1,76 +1,103 @@
#pragma once
#include <Common/FuncNameVisitor.h>
#include <Common/Hook.h>
#include <D3dDdi/ScopedCriticalSection.h>
#include <DDraw/ScopedThreadLock.h>
#include <string>
#include <typeinfo>
#include <type_traits>
struct _D3DDDI_ADAPTERCALLBACKS;
struct _D3DDDI_ADAPTERFUNCS;
struct _D3DDDI_DEVICECALLBACKS;
struct _D3DDDI_DEVICEFUNCS;
#include <Common/Hook.h>
#include <Common/Log.h>
template <typename Vtable>
class ScopedVtableFuncLock : public DDraw::ScopedThreadLock {};
class VtableHookVisitorBase
{
protected:
template <typename MemberDataPtr, MemberDataPtr ptr>
static std::string& getFuncName()
{
static std::string funcName;
return funcName;
}
template <>
class ScopedVtableFuncLock<_D3DDDI_ADAPTERCALLBACKS> : public D3dDdi::ScopedCriticalSection {};
static std::string getVtableTypeName()
{
std::string name = typeid(Vtable).name();
if (0 == name.find("struct "))
{
name = name.substr(name.find(" ") + 1);
}
return name;
}
template <>
class ScopedVtableFuncLock<_D3DDDI_ADAPTERFUNCS> : public D3dDdi::ScopedCriticalSection {};
static std::string s_vtableTypeName;
};
template <>
class ScopedVtableFuncLock<_D3DDDI_DEVICECALLBACKS> : public D3dDdi::ScopedCriticalSection {};
template <>
class ScopedVtableFuncLock<_D3DDDI_DEVICEFUNCS> : public D3dDdi::ScopedCriticalSection {};
template <typename Vtable, int instanceId = 0>
class VtableHookVisitor
template <typename Vtable, typename Lock, int instanceId>
class VtableHookVisitor : public VtableHookVisitorBase<Vtable>
{
public:
VtableHookVisitor(const Vtable& srcVtable, Vtable& origVtable, const Vtable& compatVtable)
: m_srcVtable(srcVtable)
VtableHookVisitor(const Vtable& hookedVtable, Vtable& origVtable, const Vtable& compatVtable)
: m_hookedVtable(const_cast<Vtable&>(hookedVtable))
, m_origVtable(origVtable)
{
s_compatVtable = compatVtable;
}
template <typename MemberDataPtr, MemberDataPtr ptr>
void visit(const char* /*funcName*/)
void visit([[maybe_unused]] const char* funcName)
{
m_origVtable.*ptr = m_srcVtable.*ptr;
if (m_origVtable.*ptr && s_compatVtable.*ptr)
#ifdef DEBUGLOGS
getFuncName<MemberDataPtr, ptr>() = s_vtableTypeName + "::" + funcName;
if (!(s_compatVtable.*ptr))
{
Compat::hookFunction(reinterpret_cast<void*&>(m_origVtable.*ptr),
getThreadSafeFuncPtr<MemberDataPtr, ptr>(s_compatVtable.*ptr),
FuncNameVisitor<Vtable>::getFuncName<MemberDataPtr, ptr>());
s_compatVtable.*ptr = m_hookedVtable.*ptr;
}
#endif
m_origVtable.*ptr = m_hookedVtable.*ptr;
if (m_hookedVtable.*ptr && s_compatVtable.*ptr)
{
LOG_DEBUG << "Hooking function: " << getFuncName<MemberDataPtr, ptr>()
<< " (" << Compat::funcPtrToStr(m_hookedVtable.*ptr) << ')';
DWORD oldProtect = 0;
VirtualProtect(&(m_hookedVtable.*ptr), sizeof(m_hookedVtable.*ptr), PAGE_READWRITE, &oldProtect);
m_hookedVtable.*ptr = &hookFunc<MemberDataPtr, ptr>;
VirtualProtect(&(m_hookedVtable.*ptr), sizeof(m_hookedVtable.*ptr), oldProtect, &oldProtect);
}
}
private:
template <typename Result, typename... Params>
using FuncPtr = Result(STDMETHODCALLTYPE *)(Params...);
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static FuncPtr<Result, Params...> getThreadSafeFuncPtr(FuncPtr<Result, Params...>)
static Result STDMETHODCALLTYPE hookFunc(Params... params)
{
return &threadSafeFunc<MemberDataPtr, ptr, Result, Params...>;
#ifdef DEBUGLOGS
const char* funcName = getFuncName<MemberDataPtr, ptr>().c_str();
#endif
LOG_FUNC(funcName, params...);
Lock lock;
if constexpr (-1 != instanceId)
{
CompatVtableInstanceBase<Vtable>::s_origVtablePtr = &CompatVtableInstance<Vtable, Lock, instanceId>::s_origVtable;
}
if constexpr (std::is_same_v<Result, void>)
{
(s_compatVtable.*ptr)(params...);
}
else
{
return LOG_RESULT((s_compatVtable.*ptr)(params...));
}
}
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static Result STDMETHODCALLTYPE threadSafeFunc(Params... params)
{
ScopedVtableFuncLock<Vtable> lock;
CompatVtableInstanceBase<Vtable>::s_origVtablePtr = &CompatVtableInstance<Vtable, instanceId>::s_origVtable;
return (s_compatVtable.*ptr)(params...);
}
const Vtable& m_srcVtable;
Vtable& m_hookedVtable;
Vtable& m_origVtable;
static Vtable s_compatVtable;
};
template <typename Vtable, int instanceId>
Vtable VtableHookVisitor<Vtable, instanceId>::s_compatVtable = {};
#ifdef DEBUGLOGS
template <typename Vtable>
std::string VtableHookVisitorBase<Vtable>::s_vtableTypeName(getVtableTypeName());
#endif
template <typename Vtable, typename Lock, int instanceId>
Vtable VtableHookVisitor<Vtable, Lock, instanceId>::s_compatVtable = {};

View File

@ -1,27 +0,0 @@
#pragma once
template <typename Vtable>
class VtableUpdateVisitor
{
public:
VtableUpdateVisitor(const Vtable& preHookOrigVtable, const Vtable& postHookOrigVtable, Vtable& vtable)
: m_preHookOrigVtable(preHookOrigVtable)
, m_postHookOrigVtable(postHookOrigVtable)
, m_vtable(vtable)
{
}
template <typename MemberDataPtr, MemberDataPtr ptr>
void visit(const char* /*funcName*/)
{
if (m_preHookOrigVtable.*ptr == m_vtable.*ptr)
{
m_vtable.*ptr = m_postHookOrigVtable.*ptr;
}
}
private:
const Vtable& m_preHookOrigVtable;
const Vtable& m_postHookOrigVtable;
Vtable& m_vtable;
};

View File

@ -1,7 +1,5 @@
#pragma once
#include <map>
#include <Common/CompatVtableInstance.h>
#include <Common/Log.h>
#include <Common/VtableVisitor.h>
@ -16,76 +14,59 @@ namespace D3dDdi
public:
static void hookVtable(HMODULE module, const Vtable* vtable)
{
if (!vtable)
if (vtable)
{
return;
}
auto it = s_origModuleVtables.find(module);
if (s_origModuleVtables.end() == it)
{
it = s_origModuleVtables.emplace(module, hookVtableInstance(*vtable, InstanceId<0>())).first;
hookVtableInstance(module, *vtable, InstanceId<0>());
}
}
static std::map<HMODULE, const Vtable&> s_origModuleVtables;
static Vtable*& s_origVtablePtr;
private:
template <int instanceId>
class Visitor
class CopyVisitor
{
public:
Visitor(Vtable& compatVtable)
CopyVisitor(Vtable& compatVtable, const Vtable& origVtable)
: m_compatVtable(compatVtable)
, m_origVtable(origVtable)
{
}
template <typename MemberDataPtr, MemberDataPtr ptr>
void visit(const char* /*funcName*/)
{
if (!(m_compatVtable.*ptr))
{
m_compatVtable.*ptr = &threadSafeFunc<MemberDataPtr, ptr>;
}
m_compatVtable.*ptr = m_origVtable.*ptr;
}
private:
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static Result APIENTRY threadSafeFunc(Params... params)
{
D3dDdi::ScopedCriticalSection lock;
return (CompatVtableInstance<Vtable, instanceId>::s_origVtable.*ptr)(params...);
}
Vtable& m_compatVtable;
const Vtable& m_origVtable;
};
template <int instanceId> struct InstanceId {};
template <int instanceId>
static const Vtable& hookVtableInstance(const Vtable& vtable, InstanceId<instanceId>)
static const Vtable& hookVtableInstance(HMODULE module, const Vtable& vtable, InstanceId<instanceId>)
{
static bool isHooked = false;
if (isHooked)
static HMODULE hookedModule = nullptr;
if (hookedModule && hookedModule != module)
{
return hookVtableInstance(vtable, InstanceId<instanceId + 1>());
return hookVtableInstance(module, vtable, InstanceId<instanceId + 1>());
}
hookedModule = module;
Vtable compatVtable = {};
CopyVisitor copyVisitor(compatVtable, vtable);
forEach<Vtable>(copyVisitor);
Compat::setCompatVtable(compatVtable);
#ifndef DEBUGLOGS
Visitor<instanceId> visitor(compatVtable);
forEach<Vtable>(visitor);
#endif
isHooked = true;
CompatVtableInstance<Vtable, instanceId>::hookVtable(vtable, compatVtable);
return CompatVtableInstance<Vtable, instanceId>::s_origVtable;
CompatVtableInstance<Vtable, ScopedCriticalSection, instanceId>::hookVtable(vtable, compatVtable);
return CompatVtableInstance<Vtable, ScopedCriticalSection, instanceId>::s_origVtable;
}
static const Vtable& hookVtableInstance(const Vtable& /*vtable*/, InstanceId<Config::maxUserModeDisplayDrivers>)
static const Vtable& hookVtableInstance(HMODULE /*module*/, const Vtable& /*vtable*/,
InstanceId<Config::maxUserModeDisplayDrivers>)
{
Compat::Log() << "ERROR: Cannot hook more than " << Config::maxUserModeDisplayDrivers <<
" user-mode display drivers. Recompile with Config::maxUserModeDisplayDrivers > " <<
@ -95,9 +76,6 @@ namespace D3dDdi
}
};
template <typename Vtable>
std::map<HMODULE, const Vtable&> D3dDdiVtable<Vtable>::s_origModuleVtables;
template <typename Vtable>
Vtable*& D3dDdiVtable<Vtable>::s_origVtablePtr = CompatVtableInstanceBase<Vtable>::s_origVtablePtr;
}

View File

@ -41,7 +41,8 @@ namespace
{
Compat::hookFunction(g_hookedUmdModule, "OpenAdapter",
reinterpret_cast<void*&>(g_origOpenAdapter), &openAdapter);
FreeLibrary(g_hookedUmdModule);
HMODULE module = nullptr;
GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_PIN, umdFileName.c_str(), &module);
}
}

View File

@ -136,13 +136,10 @@
<ClInclude Include="Common\CompatVtable.h" />
<ClInclude Include="Common\CompatVtableInstance.h" />
<ClInclude Include="Common\CompatWeakPtr.h" />
<ClInclude Include="Common\FuncNameVisitor.h" />
<ClInclude Include="Common\HResultException.h" />
<ClInclude Include="Common\Log.h" />
<ClInclude Include="Common\LogWrapperVisitor.h" />
<ClInclude Include="Common\ScopedSrwLock.h" />
<ClInclude Include="Common\VtableHookVisitor.h" />
<ClInclude Include="Common\VtableUpdateVisitor.h" />
<ClInclude Include="Common\VtableVisitor.h" />
<ClInclude Include="Common\Hook.h" />
<ClInclude Include="Common\ScopedCriticalSection.h" />

View File

@ -303,21 +303,12 @@
<ClInclude Include="D3dDdi\ScopedCriticalSection.h">
<Filter>Header Files\D3dDdi</Filter>
</ClInclude>
<ClInclude Include="Common\LogWrapperVisitor.h">
<Filter>Header Files\Common</Filter>
</ClInclude>
<ClInclude Include="Common\CompatVtableInstance.h">
<Filter>Header Files\Common</Filter>
</ClInclude>
<ClInclude Include="Common\VtableHookVisitor.h">
<Filter>Header Files\Common</Filter>
</ClInclude>
<ClInclude Include="Common\VtableUpdateVisitor.h">
<Filter>Header Files\Common</Filter>
</ClInclude>
<ClInclude Include="Common\FuncNameVisitor.h">
<Filter>Header Files\Common</Filter>
</ClInclude>
<ClInclude Include="D3dDdi\Adapter.h">
<Filter>Header Files\D3dDdi</Filter>
</ClInclude>

View File

@ -1,6 +1,7 @@
#include <Common/CompatPtr.h>
#include <Common/CompatRef.h>
#include <D3dDdi/Device.h>
#include <D3dDdi/ScopedCriticalSection.h>
#include <DDraw/Surfaces/Surface.h>
#include <Direct3d/Direct3dDevice.h>
#include <Direct3d/Types.h>

View File

@ -1,5 +1,6 @@
#include <vector>
#include <D3dDdi/ScopedCriticalSection.h>
#include <Gdi/CompatDc.h>
#include <Gdi/ScrollBar.h>
#include <Gdi/ScrollFunctions.h>