1
0
mirror of https://github.com/narzoul/DDrawCompat synced 2024-12-30 08:55:36 +01:00

499 lines
13 KiB
C++

#undef CINTERFACE
#include <list>
#include <sstream>
#include <string>
#include <Windows.h>
#include <initguid.h>
#include <DbgEng.h>
#include <Common/Hook.h>
#include <Common/Log.h>
#include <Common/Path.h>
#include <Dll/Dll.h>
namespace
{
// DbgEng interfaces created by initDbgEng() and released by Compat::closeDbgEng().
IDebugClient4* g_debugClient = nullptr;
IDebugControl* g_debugControl = nullptr;
IDebugSymbols* g_debugSymbols = nullptr;
IDebugDataSpaces4* g_debugDataSpaces = nullptr;
// Start of the valid virtual region located by initDbgEng() in the opened DbgEng target.
// getInstructionSize() copies instruction bytes there to disassemble them.
// initDbgEng() treats 0 as "DbgEng is not usable".
ULONG64 g_debugBase = 0;
// Set by the first initDbgEng() call so that setup is attempted only once (cleared by Compat::closeDbgEng()).
bool g_isDbgEngInitialized = false;
// Forward declarations of functions defined further down in this namespace.
LONG WINAPI dbgEngWinVerifyTrust(HWND hwnd, GUID* pgActionID, LPVOID pWVTData);
PIMAGE_NT_HEADERS getImageNtHeaders(HMODULE module);
bool initDbgEng();
/// Replacement for GetProcAddress in dbgeng.dll's import table (installed by initDbgEng()).
/// Lookups of "WinVerifyTrust" receive dbgEngWinVerifyTrust; every other lookup is forwarded
/// unchanged to the real GetProcAddress.
/// @param hModule     Module to search.
/// @param lpProcName  Export name, or an ordinal encoded in the low word (HIWORD == 0).
/// @return The resolved function address, or nullptr if not found.
FARPROC WINAPI dbgEngGetProcAddress(HMODULE hModule, LPCSTR lpProcName)
{
	LOG_FUNC("dbgEngGetProcAddress", hModule, lpProcName);
	// lpProcName may be an ordinal rather than a pointer to a string; only call strcmp on real strings.
	if (0 != HIWORD(lpProcName) && 0 == strcmp(lpProcName, "WinVerifyTrust"))
	{
		return LOG_RESULT(reinterpret_cast<FARPROC>(&dbgEngWinVerifyTrust));
	}
	return LOG_RESULT(GetProcAddress(hModule, lpProcName));
}
/// Stand-in for WinVerifyTrust, handed to dbgeng.dll by dbgEngGetProcAddress().
/// It never inspects its arguments and always reports success (ERROR_SUCCESS, i.e. 0).
LONG WINAPI dbgEngWinVerifyTrust([[maybe_unused]] HWND hwnd, [[maybe_unused]] GUID* pgActionID,
	[[maybe_unused]] LPVOID pWVTData)
{
	// The parameters are only referenced for logging; [[maybe_unused]] presumably covers builds
	// where LOG_FUNC expands to nothing.
	LOG_FUNC("dbgEngWinVerifyTrust", hwnd, pgActionID, pWVTData);
	const LONG trusted = ERROR_SUCCESS;
	return LOG_RESULT(trusted);
}
/// Finds the import address table (IAT) slot through which `module` calls the function it imports
/// under the name `procName`. Only imports by name are matched; the first match across all import
/// descriptors is returned.
/// @param module    Loaded module whose import table is searched.
/// @param procName  Imported function name (must be a string, not an ordinal).
/// @return Pointer to the matching IAT entry, or nullptr if the module has no such import.
FARPROC* findProcAddressInIat(HMODULE module, const char* procName)
{
	// Ordinals (HIWORD == 0) cannot be compared with strcmp below.
	if (!module || !procName || 0 == HIWORD(procName))
	{
		return nullptr;
	}

	PIMAGE_NT_HEADERS ntHeaders = getImageNtHeaders(module);
	if (!ntHeaders)
	{
		return nullptr;
	}

	// A module without imports has a zero import directory RVA; without this check the DOS header at
	// the module base would be misread as an import descriptor array.
	const IMAGE_DATA_DIRECTORY& importDirEntry =
		ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
	if (0 == importDirEntry.VirtualAddress)
	{
		return nullptr;
	}

	char* moduleBase = reinterpret_cast<char*>(module);
	PIMAGE_IMPORT_DESCRIPTOR importDesc =
		reinterpret_cast<PIMAGE_IMPORT_DESCRIPTOR>(moduleBase + importDirEntry.VirtualAddress);
	// The descriptor list ends at an entry with zero Characteristics; a Name of 0xFFFF also stops the scan.
	for (PIMAGE_IMPORT_DESCRIPTOR desc = importDesc;
		0 != desc->Characteristics && 0xFFFF != desc->Name;
		++desc)
	{
		// FirstThunk holds the resolved addresses (the IAT); OriginalFirstThunk is the parallel table
		// that still describes each import by name or ordinal.
		auto thunk = reinterpret_cast<PIMAGE_THUNK_DATA>(moduleBase + desc->FirstThunk);
		auto origThunk = reinterpret_cast<PIMAGE_THUNK_DATA>(moduleBase + desc->OriginalFirstThunk);
		while (0 != thunk->u1.AddressOfData && 0 != origThunk->u1.AddressOfData)
		{
			if (!(origThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG))
			{
				auto origImport = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>(moduleBase + origThunk->u1.AddressOfData);
				if (0 == strcmp(origImport->Name, procName))
				{
					return reinterpret_cast<FARPROC*>(&thunk->u1.Function);
				}
			}
			++thunk;
			++origThunk;
		}
	}
	return nullptr;
}
/// Locates the PE NT headers of the image mapped at `module`.
/// `module` must be non-null and point at readable memory; this is not checked here.
/// @return The NT headers, or nullptr if the DOS or NT signature does not match.
PIMAGE_NT_HEADERS getImageNtHeaders(HMODULE module)
{
	auto* const imageBase = reinterpret_cast<char*>(module);
	auto* const dos = reinterpret_cast<PIMAGE_DOS_HEADER>(imageBase);
	if (dos->e_magic != IMAGE_DOS_SIGNATURE)
	{
		return nullptr;
	}

	// e_lfanew is the file offset (equal to the RVA) of the NT headers.
	auto* const nt = reinterpret_cast<PIMAGE_NT_HEADERS>(imageBase + dos->e_lfanew);
	return nt->Signature == IMAGE_NT_SIGNATURE ? nt : nullptr;
}
unsigned getInstructionSize(void* instruction)
{
const unsigned MAX_INSTRUCTION_SIZE = 15;
HRESULT result = g_debugDataSpaces->WriteVirtual(g_debugBase, instruction, MAX_INSTRUCTION_SIZE, nullptr);
if (FAILED(result))
{
LOG_ONCE("ERROR: DbgEng: WriteVirtual failed: " << Compat::hex(result));
return 0;
}
ULONG64 endOffset = 0;
result = g_debugControl->Disassemble(g_debugBase, 0, nullptr, 0, nullptr, &endOffset);
if (FAILED(result))
{
LOG_ONCE("ERROR: DbgEng: Disassemble failed: " << Compat::hex(result) << " "
<< Compat::hexDump(instruction, MAX_INSTRUCTION_SIZE));
return 0;
}
return static_cast<unsigned>(endOffset - g_debugBase);
}
void hookFunction(void*& origFuncPtr, void* newFuncPtr, const char* funcName)
{
BYTE* targetFunc = static_cast<BYTE*>(origFuncPtr);
std::ostringstream oss;
oss << Compat::funcPtrToStr(targetFunc) << ' ';
char origFuncPtrStr[20] = {};
if (!funcName)
{
sprintf_s(origFuncPtrStr, "%p", origFuncPtr);
funcName = origFuncPtrStr;
}
auto prevTargetFunc = targetFunc;
while (true)
{
unsigned instructionSize = 0;
if (0xE9 == targetFunc[0])
{
instructionSize = 5;
targetFunc += instructionSize + *reinterpret_cast<int*>(targetFunc + 1);
}
else if (0xEB == targetFunc[0])
{
instructionSize = 2;
targetFunc += instructionSize + *reinterpret_cast<signed char*>(targetFunc + 1);
}
else if (0xFF == targetFunc[0] && 0x25 == targetFunc[1])
{
instructionSize = 6;
targetFunc = **reinterpret_cast<BYTE***>(targetFunc + 2);
if (Compat::getModuleHandleFromAddress(targetFunc) == Compat::getModuleHandleFromAddress(prevTargetFunc))
{
targetFunc = prevTargetFunc;
break;
}
}
else
{
break;
}
Compat::LogStream(oss) << Compat::hexDump(prevTargetFunc, instructionSize) << " -> "
<< Compat::funcPtrToStr(targetFunc) << ' ';
prevTargetFunc = targetFunc;
}
if (Compat::getModuleHandleFromAddress(targetFunc) == Dll::g_currentModule)
{
LOG_INFO << "ERROR: Target function is already hooked: " << funcName;
return;
}
if (!initDbgEng())
{
return;
}
const DWORD trampolineSize = 32;
BYTE* trampoline = static_cast<BYTE*>(
VirtualAlloc(nullptr, trampolineSize, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE));
BYTE* src = targetFunc;
BYTE* dst = trampoline;
while (src - targetFunc < 5)
{
unsigned instructionSize = getInstructionSize(src);
if (0 == instructionSize)
{
return;
}
memcpy(dst, src, instructionSize);
if (0xE8 == *src && 5 == instructionSize)
{
*reinterpret_cast<int*>(dst + 1) += src - dst;
}
src += instructionSize;
dst += instructionSize;
}
LOG_DEBUG << "Hooking function: " << funcName
<< " (" << oss.str() << Compat::hexDump(targetFunc, src - targetFunc) << ')';
*dst = 0xE9;
*reinterpret_cast<int*>(dst + 1) = src - (dst + 5);
DWORD oldProtect = 0;
VirtualProtect(trampoline, trampolineSize, PAGE_EXECUTE_READ, &oldProtect);
VirtualProtect(targetFunc, src - targetFunc, PAGE_EXECUTE_READWRITE, &oldProtect);
targetFunc[0] = 0xE9;
*reinterpret_cast<int*>(targetFunc + 1) = static_cast<BYTE*>(newFuncPtr) - (targetFunc + 5);
memset(targetFunc + 5, 0xCC, src - targetFunc - 5);
VirtualProtect(targetFunc, src - targetFunc, PAGE_EXECUTE_READ, &oldProtect);
FlushInstructionCache(GetCurrentProcess(), nullptr, 0);
HMODULE module = nullptr;
GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_PIN,
reinterpret_cast<char*>(targetFunc), &module);
origFuncPtr = trampoline;
}
bool initDbgEng()
{
if (g_isDbgEngInitialized)
{
return 0 != g_debugBase;
}
g_isDbgEngInitialized = true;
if (!GetModuleHandle("dbghelp.dll"))
{
LoadLibraryW((Compat::getSystemPath() / "dbghelp.dll").c_str());
}
auto dbgEng = LoadLibraryW((Compat::getSystemPath() / "dbgeng.dll").c_str());
if (!dbgEng)
{
LOG_INFO << "ERROR: DbgEng: failed to load library";
return false;
}
Compat::hookIatFunction(dbgEng, "GetProcAddress", dbgEngGetProcAddress);
auto debugCreate = reinterpret_cast<decltype(&DebugCreate)>(Compat::getProcAddress(dbgEng, "DebugCreate"));
if (!debugCreate)
{
LOG_INFO << "ERROR: DbgEng: DebugCreate not found";
return false;
}
HRESULT result = S_OK;
if (FAILED(result = debugCreate(IID_IDebugClient4, reinterpret_cast<void**>(&g_debugClient))) ||
FAILED(result = g_debugClient->QueryInterface(IID_IDebugControl, reinterpret_cast<void**>(&g_debugControl))) ||
FAILED(result = g_debugClient->QueryInterface(IID_IDebugSymbols, reinterpret_cast<void**>(&g_debugSymbols))) ||
FAILED(result = g_debugClient->QueryInterface(IID_IDebugDataSpaces4, reinterpret_cast<void**>(&g_debugDataSpaces))))
{
LOG_INFO << "ERROR: DbgEng: object creation failed: " << Compat::hex(result);
return false;
}
result = g_debugClient->OpenDumpFileWide(Compat::getModulePath(Dll::g_currentModule).c_str(), 0);
if (FAILED(result))
{
LOG_INFO << "ERROR: DbgEng: OpenDumpFile failed: " << Compat::hex(result);
return false;
}
g_debugControl->SetEngineOptions(DEBUG_ENGOPT_DISABLE_MODULE_SYMBOL_LOAD);
result = g_debugControl->WaitForEvent(0, INFINITE);
if (FAILED(result))
{
LOG_INFO << "ERROR: DbgEng: WaitForEvent failed: " << Compat::hex(result);
return false;
}
DEBUG_MODULE_PARAMETERS dmp = {};
result = g_debugSymbols->GetModuleParameters(1, 0, 0, &dmp);
if (FAILED(result))
{
LOG_INFO << "ERROR: DbgEng: GetModuleParameters failed: " << Compat::hex(result);
return false;
}
ULONG size = 0;
result = g_debugDataSpaces->GetValidRegionVirtual(dmp.Base, dmp.Size, &g_debugBase, &size);
if (FAILED(result) || 0 == g_debugBase)
{
LOG_INFO << "ERROR: DbgEng: GetValidRegionVirtual failed: " << Compat::hex(result);
return false;
}
return true;
}
}
namespace Compat
{
/// Ends the DbgEng session, releases all DbgEng interfaces and resets the state so that the next
/// initDbgEng() call performs setup again.
void closeDbgEng()
{
	if (g_debugClient)
	{
		g_debugClient->EndSession(DEBUG_END_PASSIVE);
	}

	// Release an interface if it was acquired, and clear the global that held it.
	auto release = [](auto*& iface)
	{
		if (iface)
		{
			iface->Release();
			iface = nullptr;
		}
	};
	release(g_debugDataSpaces);
	release(g_debugSymbols);
	release(g_debugControl);
	release(g_debugClient);

	g_debugBase = 0;
	g_isDbgEngInitialized = false;
}
/// Formats a code address for logging.
/// @return "<module path>+0x<hex offset>" when the address lies inside a loaded module,
///         otherwise the raw pointer value.
std::string funcPtrToStr(void* funcPtr)
{
	std::ostringstream oss;
	HMODULE module = Compat::getModuleHandleFromAddress(funcPtr);
	if (module)
	{
		// Pointer subtraction keeps full pointer width; casting both pointers to DWORD would
		// truncate them on 64-bit targets.
		const auto offset = static_cast<const char*>(funcPtr) - reinterpret_cast<const char*>(module);
		oss << Compat::getModulePath(module).u8string() << "+0x" << std::hex << offset;
	}
	else
	{
		oss << funcPtr;
	}
	return oss.str();
}
/// Returns the handle of the loaded module whose image contains `address`, or nullptr if none does.
/// The module's reference count is not changed, so the caller must not free the handle.
HMODULE getModuleHandleFromAddress(void* address)
{
	const DWORD flags = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
	HMODULE owner = nullptr;
	GetModuleHandleEx(flags, static_cast<const char*>(address), &owner);
	return owner;
}
/// Resolves an export by reading the module's export table directly (not through the loader).
/// @param module    Loaded module to search; nullptr yields nullptr.
/// @param procName  Export name, or an ordinal encoded in the low word (HIWORD == 0).
/// @return The export's address, or nullptr if it does not exist. A forwarded export is followed
///         only when its target module is already loaded (GetModuleHandle; nothing is loaded).
FARPROC getProcAddress(HMODULE module, const char* procName)
{
	if (!module || !procName)
	{
		return nullptr;
	}

	PIMAGE_NT_HEADERS ntHeaders = getImageNtHeaders(module);
	if (!ntHeaders)
	{
		return nullptr;
	}

	const IMAGE_DATA_DIRECTORY& exportDirEntry =
		ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT];
	if (0 == exportDirEntry.VirtualAddress)
	{
		// No export table: otherwise the DOS header would be misread as an export directory.
		return nullptr;
	}

	char* moduleBase = reinterpret_cast<char*>(module);
	PIMAGE_EXPORT_DIRECTORY exportDir = reinterpret_cast<PIMAGE_EXPORT_DIRECTORY>(
		moduleBase + exportDirEntry.VirtualAddress);
	auto exportDirSize = exportDirEntry.Size;
	DWORD* rvaOfNames = reinterpret_cast<DWORD*>(moduleBase + exportDir->AddressOfNames);
	WORD* nameOrds = reinterpret_cast<WORD*>(moduleBase + exportDir->AddressOfNameOrdinals);
	DWORD* rvaOfFunctions = reinterpret_cast<DWORD*>(moduleBase + exportDir->AddressOfFunctions);

	DWORD funcRva = 0;
	if (0 == HIWORD(procName))
	{
		WORD ord = LOWORD(procName);
		if (ord < exportDir->Base || ord >= exportDir->Base + exportDir->NumberOfFunctions)
		{
			return nullptr;
		}
		funcRva = rvaOfFunctions[ord - exportDir->Base];
	}
	else
	{
		for (DWORD i = 0; i < exportDir->NumberOfNames; ++i)
		{
			if (0 == strcmp(procName, moduleBase + rvaOfNames[i]))
			{
				funcRva = rvaOfFunctions[nameOrds[i]];
				break;
			}
		}
	}

	// RVA 0 means either the name was not found or the ordinal's slot in the export address table is
	// unused; never hand out the module base as a function.
	if (0 == funcRva)
	{
		return nullptr;
	}

	char* func = moduleBase + funcRva;
	// An RVA pointing inside the export directory is a forwarder string "Module.Function" or "Module.#Ordinal".
	if (func >= reinterpret_cast<char*>(exportDir) &&
		func < reinterpret_cast<char*>(exportDir) + exportDirSize)
	{
		std::string forw(func);
		auto separatorPos = forw.find_first_of('.');
		if (std::string::npos == separatorPos)
		{
			return nullptr;
		}
		HMODULE forwModule = GetModuleHandle(forw.substr(0, separatorPos).c_str());
		std::string forwFuncName = forw.substr(separatorPos + 1);
		if ('#' == forwFuncName[0])
		{
			int32_t ord = std::atoi(forwFuncName.substr(1).c_str());
			if (ord < 0 || ord > 0xFFFF)
			{
				return nullptr;
			}
			return getProcAddress(forwModule, reinterpret_cast<const char*>(ord));
		}
		else
		{
			return getProcAddress(forwModule, forwFuncName.c_str());
		}
	}
	return reinterpret_cast<FARPROC>(func);
}
/// Public entry point for inline-hooking a function given by address.
/// It forwards to the implementation defined in this file's anonymous namespace.
void hookFunction(void*& origFuncPtr, void* newFuncPtr, const char* funcName)
{
	return ::hookFunction(origFuncPtr, newFuncPtr, funcName);
}
/// Looks up an export of `module` (by name, or by ordinal via MAKEINTRESOURCE) and inline-hooks it.
/// @param module       Loaded module exporting the function.
/// @param funcName     Export name, or an ordinal encoded in the low word (HIWORD == 0).
/// @param origFuncPtr  Out: the export's address, replaced by the trampoline if the hook succeeds.
/// @param newFuncPtr   Replacement function.
void hookFunction(HMODULE module, const char* funcName, void*& origFuncPtr, void* newFuncPtr)
{
	// getProcAddress accepts ordinals, which must never be streamed or used as C strings;
	// build a printable "#<ordinal>" name for log messages instead.
	const bool isOrdinal = 0 == HIWORD(funcName);
	const std::string displayName = isOrdinal ? "#" + std::to_string(LOWORD(funcName)) : std::string(funcName);

	FARPROC procAddr = getProcAddress(module, funcName);
	if (!procAddr)
	{
		LOG_DEBUG << "ERROR: Failed to load the address of a function: " << displayName.c_str();
		return;
	}
	origFuncPtr = reinterpret_cast<void*>(procAddr);
	::hookFunction(origFuncPtr, newFuncPtr, displayName.c_str());
}
void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr)
{
HMODULE module = LoadLibrary(moduleName);
if (!module)
{
return;
}
hookFunction(module, funcName, origFuncPtr, newFuncPtr);
FreeLibrary(module);
}
/// Redirects `module`'s import of `funcName` to `newFuncPtr` by overwriting its IAT entry.
/// Only this module's calls are affected. Nothing happens if the import is not found or the entry
/// cannot be made writable.
void hookIatFunction(HMODULE module, const char* funcName, void* newFuncPtr)
{
	FARPROC* func = findProcAddressInIat(module, funcName);
	if (!func)
	{
		return;
	}

	LOG_DEBUG << "Hooking function via IAT: " << funcName << " ("
		<< funcPtrToStr(reinterpret_cast<void*>(*func)) << ')';
	DWORD oldProtect = 0;
	// Protect exactly the IAT slot (a FARPROC), not the size of the pointer to it.
	if (!VirtualProtect(func, sizeof(*func), PAGE_READWRITE, &oldProtect))
	{
		LOG_INFO << "ERROR: Failed to make the IAT entry writable: " << funcName;
		return;
	}
	// Converting between void* and function pointers requires reinterpret_cast; static_cast is ill-formed.
	*func = reinterpret_cast<FARPROC>(newFuncPtr);
	DWORD dummy = 0;
	VirtualProtect(func, sizeof(*func), oldProtect, &dummy);
}
/// Bypasses a shim on an export of `module`.
/// The shim is detected when the loader's GetProcAddress returns a different address than the raw
/// export-table lookup done by getProcAddress. In that case the address returned by GetProcAddress is
/// inline-hooked to jump straight to the real export.
void removeShim(HMODULE module, const char* funcName)
{
	void* shimFunc = reinterpret_cast<void*>(GetProcAddress(module, funcName));
	if (!shimFunc)
	{
		return;
	}

	void* realFunc = reinterpret_cast<void*>(getProcAddress(module, funcName));
	if (!realFunc || shimFunc == realFunc)
	{
		return;
	}

	// hookFunction writes the trampoline back through a void*&. std::list elements never relocate,
	// so the slot handed out for each shim stays valid as more shims are added.
	static std::list<void*> shimFuncs;
	shimFuncs.push_back(shimFunc);
	std::string shimFuncName("[shim]");
	shimFuncName += funcName;
	hookFunction(shimFuncs.back(), realFunc, shimFuncName.c_str());
}
}