1
0
mirror of https://github.com/narzoul/DDrawCompat synced 2024-12-30 08:55:36 +01:00
DDrawCompat/DDrawCompat/CompatVtable.h
2016-03-21 22:51:56 +01:00

188 lines
5.5 KiB
C++

#pragma once
#include <map>
#include <string>
#include <type_traits>
#include <vector>
#include "DDrawLog.h"
#include "DDrawVtableVisitor.h"
#include "Hook.h"
// Vtable<Interface> names the vtable struct that Interface::lpVtbl points to:
// it is the declared type of the lpVtbl member with one level of pointer removed.
// Example: if Interface declares `FooVtbl* lpVtbl;`, Vtable<Interface> is FooVtbl.
// std::remove_pointer is declared in <type_traits>.
template <typename Interface>
using Vtable = typename std::remove_pointer<decltype(Interface::lpVtbl)>::type;
namespace Compat
{
    // Bookkeeping record for one method that has already been hooked.
    // Both members are references: the record never owns the pointer or the
    // map, it only aliases storage owned by whoever registered the hook.
    struct HookedMethodInfo
    {
        // Alias of the registering side's "original method" pointer slot.
        void*& updatedOrigMethodPtr;

        // Alias of the registering side's map from an interface vtable pointer
        // to the address of the compat vtable that serves it.
        std::map<void*, void*>& vtablePtrToCompatVtable;

        // Binds both references; nothing is copied.
        HookedMethodInfo(void*& origMethodSlot, std::map<void*, void*>& compatVtableMap)
            : updatedOrigMethodPtr(origMethodSlot)
            , vtablePtrToCompatVtable(compatVtableMap)
        {
        }
    };

    // Registry of hooked methods keyed by a method address; defined in another
    // translation unit.
    extern std::map<void*, HookedMethodInfo> g_hookedMethods;
}
// Captures the vtable of an interface that exposes an lpVtbl pointer, builds
// wrapper functions for its methods and hooks them over the originals.
//
// Template parameters:
//   CompatInterface - must provide a static setCompatVtable(Vtable<Interface>&)
//                     that fills in the replacement methods (see
//                     createCompatVtable). Slots it leaves null fall back to
//                     the original method.
//   Interface       - interface type with an lpVtbl member; Vtable<Interface>
//                     is the struct that member points to.
//
// All state is static, so there is one set of tables per
// (CompatInterface, Interface) instantiation.
template <typename CompatInterface, typename Interface>
class CompatVtable
{
public:
// Re-exports the Interface template parameter as a member type.
// NOTE(review): redeclaring a template parameter's name inside its scope is
// rejected by conforming compilers; this presumably relies on MSVC accepting
// it - confirm before building with another toolchain.
typedef Interface Interface;
// One-time setup for this instantiation. On the first call it remembers
// intf.lpVtbl, copies the pointed-to vtable into s_origVtable and runs an
// InitVisitor over the vtable members through forEach (declared outside this
// file's visible code, presumably in DDrawVtableVisitor.h). Every later call
// is a no-op, regardless of which object is passed.
// NOTE(review): isInitialized is a plain bool with no synchronization visible
// here.
static void hookVtable(Interface& intf)
{
static bool isInitialized = false;
if (!isInitialized)
{
isInitialized = true;
s_vtablePtr = intf.lpVtbl;
s_origVtable = *intf.lpVtbl;
InitVisitor visitor;
forEach<Vtable<Interface>>(visitor);
}
}
// Copy of the interface's vtable taken in hookVtable. Slots that go through
// hookMethod are passed by reference and may be overwritten there.
static Vtable<Interface> s_origVtable;
private:
// Visitor applied to each vtable member by forEach. Whether visit() or
// visitDebug() is invoked is decided by forEach, which is not visible here
// (presumably visitDebug is used in _DEBUG builds - confirm).
class InitVisitor
{
public:
// Handles one vtable slot identified by the pointer-to-member `ptr`.
// - No compat implementation for the slot: the thread-safe and compat tables
//   both get the original method and nothing is hooked.
// - Otherwise: the matching threadSafeFunc instantiation is stored in
//   s_threadSafeVtable and hooked over the original method.
template <typename MemberDataPtr, MemberDataPtr ptr>
void visit()
{
if (!(s_compatVtable.*ptr))
{
s_threadSafeVtable.*ptr = s_origVtable.*ptr;
s_compatVtable.*ptr = s_origVtable.*ptr;
}
else
{
s_threadSafeVtable.*ptr = getThreadSafeFuncPtr<MemberDataPtr, ptr>(s_compatVtable.*ptr);
hookMethod(reinterpret_cast<void*&>(s_origVtable.*ptr), s_threadSafeVtable.*ptr);
}
}
// Like visit(), but also records "<vtableTypeName>::<funcName>" in
// s_funcNames and hooks the slot unconditionally, even when there is no
// compat implementation. The compat fallback is filled in only after
// hookMethod has run, so it receives whatever hookMethod left in
// s_origVtable.*ptr.
template <typename MemberDataPtr, MemberDataPtr ptr>
void visitDebug(const std::string& vtableTypeName, const std::string& funcName)
{
s_funcNames[getKey<MemberDataPtr, ptr>()] = vtableTypeName + "::" + funcName;
s_threadSafeVtable.*ptr = getThreadSafeFuncPtr<MemberDataPtr, ptr>(s_compatVtable.*ptr);
hookMethod(reinterpret_cast<void*&>(s_origVtable.*ptr), s_threadSafeVtable.*ptr);
if (!(s_compatVtable.*ptr))
{
s_compatVtable.*ptr = s_origVtable.*ptr;
}
}
private:
// Function pointer type using the STDMETHODCALLTYPE calling convention.
template <typename Result, typename... Params>
using FuncPtr = Result(STDMETHODCALLTYPE *)(Params...);
// Returns the object representation (raw bytes) of the pointer-to-member
// value `ptr`; used as the key into s_funcNames.
template <typename MemberDataPtr, MemberDataPtr ptr>
static std::vector<unsigned char> getKey()
{
MemberDataPtr mp = ptr;
unsigned char* p = reinterpret_cast<unsigned char*>(&mp);
return std::vector<unsigned char>(p, p + sizeof(mp));
}
// Returns the address of the threadSafeFunc instantiation whose signature
// matches the slot. The (unnamed) argument is never read; it exists only so
// that Result and Params... are deduced from the slot's function pointer type.
// The first element of Params here lands in threadSafeFunc's IntfPtr
// parameter, the remainder in its Params pack.
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename... Params>
static FuncPtr<Result, Params...> getThreadSafeFuncPtr(FuncPtr<Result, Params...>)
{
return &threadSafeFunc<MemberDataPtr, ptr, Result, Params...>;
}
// Installs newMethodPtr over the method whose address is currently stored in
// origMethodPtr, unless that address is already in Compat::g_hookedMethods.
//
// origMethodPtr - reference to the slot holding the original method address;
//                 may be overwritten (see both branches).
// newMethodPtr  - replacement to install.
void hookMethod(void*& origMethodPtr, void* newMethodPtr)
{
auto it = Compat::g_hookedMethods.find(origMethodPtr);
if (it != Compat::g_hookedMethods.end())
{
// Already hooked: reuse the recorded original pointer and register this
// vtable in the map owned by the earlier registration. No new hook is
// installed, so newMethodPtr is unused on this path.
origMethodPtr = it->second.updatedOrigMethodPtr;
it->second.vtablePtrToCompatVtable[s_vtablePtr] = &s_compatVtable;
}
else
{
s_vtablePtrToCompatVtable[s_vtablePtr] = &s_compatVtable;
// The map key is the value of origMethodPtr at this point, i.e. before
// hookFunction is called; the stored HookedMethodInfo keeps a reference to
// the slot itself. Presumably hookFunction rewrites the slot so it can be
// used to reach the un-hooked code - verify against Hook.h.
Compat::g_hookedMethods.emplace(origMethodPtr,
Compat::HookedMethodInfo(origMethodPtr, s_vtablePtrToCompatVtable));
Compat::beginHookTransaction();
Compat::hookFunction(origMethodPtr, newMethodPtr);
Compat::endHookTransaction();
}
}
// Wrapper installed in place of an original method.
// Sequence: acquire the DD thread lock, (in _DEBUG) log entry, dispatch,
// (in _DEBUG) log exit together with the result, release the lock, return.
// Dispatch: This->lpVtbl is looked up in s_vtablePtrToCompatVtable; if found,
// the slot `ptr` of the mapped compat vtable is called, otherwise the slot of
// s_origVtable is called.
// NOTE(review): the lock release is a plain call, not a scope guard, so it
// would be skipped if the callee exited by exception.
template <typename MemberDataPtr, MemberDataPtr ptr, typename Result, typename IntfPtr, typename... Params>
static Result STDMETHODCALLTYPE threadSafeFunc(IntfPtr This, Params... params)
{
Compat::origProcs.AcquireDDThreadLock();
#ifdef _DEBUG
Compat::LogEnter(s_funcNames[getKey<MemberDataPtr, ptr>()].c_str(), This, params...);
#endif
Result result;
auto it = s_vtablePtrToCompatVtable.find(This->lpVtbl);
if (it != s_vtablePtrToCompatVtable.end())
{
// Map values are stored as void*; they are written in hookMethod as the
// address of a Vtable<Interface> object (&s_compatVtable).
Vtable<Interface>& compatVtable = *static_cast<Vtable<Interface>*>(it->second);
result = (compatVtable.*ptr)(This, params...);
}
else
{
result = (s_origVtable.*ptr)(This, params...);
}
#ifdef _DEBUG
Compat::LogLeave(s_funcNames[getKey<MemberDataPtr, ptr>()].c_str(), This, params...) << result;
#endif
Compat::origProcs.ReleaseDDThreadLock();
return result;
}
};
// Builds a vtable that starts with every slot value-initialized (null) and is
// then filled in by CompatInterface::setCompatVtable.
static Vtable<Interface> createCompatVtable()
{
Vtable<Interface> vtable = {};
CompatInterface::setCompatVtable(vtable);
return vtable;
}
// Returns a function-local static vtable, created on first use from
// createCompatVtable().
static Vtable<Interface>& getCompatVtable()
{
static Vtable<Interface> vtable(createCompatVtable());
return vtable;
}
// lpVtbl of the interface passed to the first hookVtable call; used as the
// key when registering s_compatVtable in hookMethod.
static Vtable<Interface>* s_vtablePtr;
// Replacement methods; null slots are filled with the original method by
// InitVisitor.
static Vtable<Interface> s_compatVtable;
// Per-slot pointer actually installed: a threadSafeFunc instantiation for
// hooked slots, or the original method for slots visit() leaves unhooked.
static Vtable<Interface> s_threadSafeVtable;
// Maps an interface vtable pointer to the address of the compat vtable that
// threadSafeFunc should dispatch to.
static std::map<void*, void*> s_vtablePtrToCompatVtable;
// Printable method names keyed by getKey(); written only in visitDebug and
// read in the _DEBUG logging calls of threadSafeFunc.
static std::map<std::vector<unsigned char>, std::string> s_funcNames;
};
// Out-of-class definitions of CompatVtable's static data members. Each
// (CompatInterface, Interface) instantiation gets its own copy of all six.

// Null until hookVtable stores the first interface's lpVtbl.
template <typename CompatInterface, typename Interface>
Vtable<Interface>* CompatVtable<CompatInterface, Interface>::s_vtablePtr = nullptr;
// All slots empty until hookVtable copies the real vtable in.
template <typename CompatInterface, typename Interface>
Vtable<Interface> CompatVtable<CompatInterface, Interface>::s_origVtable = {};
// Copy-constructed from CompatInterface::getCompatVtable(). getCompatVtable is
// declared as a private static of CompatVtable, so this presumably resolves
// through CompatInterface deriving from CompatVtable - confirm.
template <typename CompatInterface, typename Interface>
Vtable<Interface> CompatVtable<CompatInterface, Interface>::s_compatVtable(CompatInterface::getCompatVtable());
// All slots empty until InitVisitor fills them.
template <typename CompatInterface, typename Interface>
Vtable<Interface> CompatVtable<CompatInterface, Interface>::s_threadSafeVtable = {};
// Starts empty; populated by InitVisitor::hookMethod.
template <typename CompatInterface, typename Interface>
std::map<void*, void*> CompatVtable<CompatInterface, Interface>::s_vtablePtrToCompatVtable;
// Starts empty; populated by InitVisitor::visitDebug.
template <typename CompatInterface, typename Interface>
std::map<std::vector<unsigned char>, std::string> CompatVtable<CompatInterface, Interface>::s_funcNames;