From e4f8cc97d9f014dfe00a6b225b67ff04f46a6676 Mon Sep 17 00:00:00 2001 From: narzoul Date: Sun, 26 Apr 2020 23:31:33 +0200 Subject: [PATCH] Disable DWM 8/16 bit mitigation display setting hooks Fixes upside-down icons with incorrect colors in Star Wars Rebellion (issue #22). --- DDrawCompat/Common/Hook.cpp | 190 +++++++++++++----------------- DDrawCompat/Common/Hook.h | 5 +- DDrawCompat/Dll/DllMain.cpp | 10 +- DDrawCompat/Win32/DisplayMode.cpp | 46 +++++--- DDrawCompat/Win32/DisplayMode.h | 1 - 5 files changed, 112 insertions(+), 140 deletions(-) diff --git a/DDrawCompat/Common/Hook.cpp b/DDrawCompat/Common/Hook.cpp index 718832d..2b90452 100644 --- a/DDrawCompat/Common/Hook.cpp +++ b/DDrawCompat/Common/Hook.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -12,8 +11,8 @@ #include #include -#include "Common/Hook.h" -#include "Common/Log.h" +#include +#include namespace { @@ -26,12 +25,73 @@ namespace std::map g_hookedFunctions; + PIMAGE_NT_HEADERS getImageNtHeaders(HMODULE module); + HMODULE getModuleHandleFromAddress(void* address); + std::filesystem::path getModulePath(HMODULE module); + std::map::iterator findOrigFunc(void* origFunc) { return std::find_if(g_hookedFunctions.begin(), g_hookedFunctions.end(), [=](const auto& i) { return origFunc == i.first || origFunc == i.second.origFunction; }); } + FARPROC* findProcAddressInIat(HMODULE module, const char* importedModuleName, const char* procName) + { + if (!module || !importedModuleName || !procName) + { + return nullptr; + } + + PIMAGE_NT_HEADERS ntHeaders = getImageNtHeaders(module); + if (!ntHeaders) + { + return nullptr; + } + + char* moduleBase = reinterpret_cast(module); + PIMAGE_IMPORT_DESCRIPTOR importDesc = reinterpret_cast(moduleBase + + ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress); + + for (PIMAGE_IMPORT_DESCRIPTOR desc = importDesc; + 0 != desc->Characteristics && 0xFFFF != desc->Name; + ++desc) + { + if (0 != _stricmp(moduleBase + desc->Name, importedModuleName)) + { + continue; + } + + auto thunk = reinterpret_cast(moduleBase + desc->FirstThunk); + auto origThunk = reinterpret_cast(moduleBase + desc->OriginalFirstThunk); + while (0 != thunk->u1.AddressOfData && 0 != origThunk->u1.AddressOfData) + { + auto origImport = reinterpret_cast( + moduleBase + origThunk->u1.AddressOfData); + + if (0 == strcmp(origImport->Name, procName)) + { + return reinterpret_cast(&thunk->u1.Function); + } + + ++thunk; + ++origThunk; + } + + break; + } + + return nullptr; + } + + std::string funcAddrToStr(void* funcPtr) + { + std::ostringstream oss; + HMODULE module = getModuleHandleFromAddress(funcPtr); + oss << getModulePath(module).string() << "+0x" << std::hex << + reinterpret_cast(funcPtr) - reinterpret_cast(module); + return oss.str(); + } + std::vector getProcessModules(HANDLE process) { std::vector modules(10000); @@ -43,27 +103,6 @@ namespace return modules; } - std::set getIatHookFunctions(const char* moduleName, const char* funcName) - { - std::set hookFunctions; - if (!moduleName || !funcName) - { - return hookFunctions; - } - - auto modules = getProcessModules(GetCurrentProcess()); - for (auto module : modules) - { - FARPROC func = Compat::getProcAddressFromIat(module, moduleName, funcName); - if (func) - { - hookFunctions.insert(func); - } - } - - return hookFunctions; - } - PIMAGE_NT_HEADERS getImageNtHeaders(HMODULE module) { PIMAGE_DOS_HEADER dosHeader = reinterpret_cast(module); @@ -97,15 +136,6 @@ namespace return path; } - std::string funcAddrToStr(void* funcPtr) - { - std::ostringstream oss; - HMODULE module = getModuleHandleFromAddress(funcPtr); - oss << getModulePath(module).string() << "+0x" << std::hex << - reinterpret_cast(funcPtr) - reinterpret_cast(module); - return oss.str(); - } - void hookFunction(void*& origFuncPtr, void* newFuncPtr, const char* funcName) { const auto it = findOrigFunc(origFuncPtr); @@ -158,77 +188,6 @@ namespace namespace Compat { - void redirectIatHooks(const char* moduleName, const char* funcName, void* newFunc) - { - auto hookFunctions(getIatHookFunctions(moduleName, funcName)); - - for (auto hookFunc : hookFunctions) - { - HMODULE module = getModuleHandleFromAddress(hookFunc); - if (!module) - { - continue; - } - - std::string moduleBaseName(getModulePath(module).filename().string()); - if (0 != _stricmp(moduleBaseName.c_str(), moduleName)) - { - Compat::Log() << "Disabling external hook to " << funcName << " in " << moduleBaseName; - static std::list origFuncs; - origFuncs.push_back(hookFunc); - hookFunction(origFuncs.back(), newFunc, funcName); - } - } - } - - FARPROC* findProcAddressInIat(HMODULE module, const char* importedModuleName, const char* procName) - { - if (!module || !importedModuleName || !procName) - { - return nullptr; - } - - PIMAGE_NT_HEADERS ntHeaders = getImageNtHeaders(module); - if (!ntHeaders) - { - return nullptr; - } - - char* moduleBase = reinterpret_cast(module); - PIMAGE_IMPORT_DESCRIPTOR importDesc = reinterpret_cast(moduleBase + - ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress); - - for (PIMAGE_IMPORT_DESCRIPTOR desc = importDesc; - 0 != desc->Characteristics && 0xFFFF != desc->Name; - ++desc) - { - if (0 != _stricmp(moduleBase + desc->Name, importedModuleName)) - { - continue; - } - - auto thunk = reinterpret_cast(moduleBase + desc->FirstThunk); - auto origThunk = reinterpret_cast(moduleBase + desc->OriginalFirstThunk); - while (0 != thunk->u1.AddressOfData && 0 != origThunk->u1.AddressOfData) - { - auto origImport = reinterpret_cast( - moduleBase + origThunk->u1.AddressOfData); - - if (0 == strcmp(origImport->Name, procName)) - { - return reinterpret_cast(&thunk->u1.Function); - } - - ++thunk; - ++origThunk; - } - - break; - } - - return nullptr; - } - FARPROC getProcAddress(HMODULE module, const char* procName) { if (!module || !procName) @@ -314,12 +273,6 @@ namespace Compat return reinterpret_cast(func); } - FARPROC getProcAddressFromIat(HMODULE module, const char* importedModuleName, const char* procName) - { - FARPROC* proc = findProcAddressInIat(module, importedModuleName, procName); - return proc ? *proc : nullptr; - } - void hookFunction(void*& origFuncPtr, void* newFuncPtr, const char* funcName) { ::hookFunction(origFuncPtr, newFuncPtr, funcName); @@ -363,6 +316,23 @@ namespace Compat } } + void removeShim(HMODULE module, const char* funcName) + { + void* shimFunc = GetProcAddress(module, funcName); + if (shimFunc) + { + void* realFunc = getProcAddress(module, funcName); + if (realFunc && shimFunc != realFunc) + { + static std::list shimFuncs; + shimFuncs.push_back(shimFunc); + std::string shimFuncName("[shim]"); + shimFuncName += funcName; + hookFunction(shimFuncs.back(), realFunc, shimFuncName.c_str()); + } + } + } + void unhookAllFunctions() { while (!g_hookedFunctions.empty()) diff --git a/DDrawCompat/Common/Hook.h b/DDrawCompat/Common/Hook.h index e37406e..895d906 100644 --- a/DDrawCompat/Common/Hook.h +++ b/DDrawCompat/Common/Hook.h @@ -13,8 +13,6 @@ namespace Compat { - void redirectIatHooks(const char* moduleName, const char* funcName, void* newFunc); - template OrigFuncPtr& getOrigFuncPtr() { @@ -22,9 +20,7 @@ namespace Compat return origFuncPtr; } - FARPROC* findProcAddressInIat(HMODULE module, const char* importedModuleName, const char* procName); FARPROC getProcAddress(HMODULE module, const char* procName); - FARPROC getProcAddressFromIat(HMODULE module, const char* importedModuleName, const char* procName); void hookFunction(void*& origFuncPtr, void* newFuncPtr, const char* funcName); void hookFunction(HMODULE module, const char* funcName, void*& origFuncPtr, void* newFuncPtr); void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr); @@ -37,6 +33,7 @@ namespace Compat reinterpret_cast(getOrigFuncPtr()), newFuncPtr); } + void removeShim(HMODULE module, const char* funcName); void unhookAllFunctions(); void unhookFunction(void* origFunc); } diff --git a/DDrawCompat/Dll/DllMain.cpp b/DDrawCompat/Dll/DllMain.cpp index aa8ba92..f19ad6e 100644 --- a/DDrawCompat/Dll/DllMain.cpp +++ b/DDrawCompat/Dll/DllMain.cpp @@ -32,13 +32,13 @@ namespace static bool isAlreadyInstalled = false; if (!isAlreadyInstalled) { - Win32::DisplayMode::disableDwm8And16BitMitigation(); + Compat::Log() << "Installing display mode hooks"; + Win32::DisplayMode::installHooks(); Compat::Log() << "Installing registry hooks"; Win32::Registry::installHooks(); Compat::Log() << "Installing Direct3D driver hooks"; D3dDdi::installHooks(g_origDDrawModule); - Compat::Log() << "Installing display mode hooks"; - Win32::DisplayMode::installHooks(); + Compat::Log() << "Installing Win32 hooks"; Win32::TimeFunctions::installHooks(); Win32::WaitFunctions::installHooks(); Gdi::VirtualScreen::init(); @@ -144,10 +144,6 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) SetProcessDPIAware(); SetThemeAppProperties(0); - Compat::redirectIatHooks("ddraw.dll", "DirectDrawCreate", - Compat::getProcAddress(hinstDLL, "DirectDrawCreate")); - Compat::redirectIatHooks("ddraw.dll", "DirectDrawCreateEx", - Compat::getProcAddress(hinstDLL, "DirectDrawCreateEx")); Win32::FontSmoothing::g_origSystemSettings = Win32::FontSmoothing::getSystemSettings(); Win32::MsgHooks::installHooks(); Time::init(); diff --git a/DDrawCompat/Win32/DisplayMode.cpp b/DDrawCompat/Win32/DisplayMode.cpp index 4c776bb..523852a 100644 --- a/DDrawCompat/Win32/DisplayMode.cpp +++ b/DDrawCompat/Win32/DisplayMode.cpp @@ -1,13 +1,13 @@ #include #include -#include "Common/CompatPtr.h" -#include "Common/Hook.h" -#include "DDraw/DirectDraw.h" -#include "DDraw/ScopedThreadLock.h" -#include "Gdi/Gdi.h" -#include "Gdi/VirtualScreen.h" -#include "Win32/DisplayMode.h" +#include +#include +#include +#include +#include +#include +#include BOOL WINAPI DWM8And16Bit_IsShimApplied_CallOut() { return FALSE; }; @@ -34,13 +34,11 @@ namespace DWORD g_currentBpp = 0; DWORD g_lastBpp = 0; - BOOL WINAPI enumDisplaySettingsExA( - LPCSTR lpszDeviceName, DWORD iModeNum, DEVMODEA* lpDevMode, DWORD dwFlags); - BOOL WINAPI enumDisplaySettingsExW( - LPCWSTR lpszDeviceName, DWORD iModeNum, DEVMODEW* lpDevMode, DWORD dwFlags); + BOOL WINAPI dwm8And16BitIsShimAppliedCallOut(); + BOOL WINAPI enumDisplaySettingsExA(LPCSTR lpszDeviceName, DWORD iModeNum, DEVMODEA* lpDevMode, DWORD dwFlags); + BOOL WINAPI enumDisplaySettingsExW(LPCWSTR lpszDeviceName, DWORD iModeNum, DEVMODEW* lpDevMode, DWORD dwFlags); - template + template LONG changeDisplaySettingsEx( ChangeDisplaySettingsExFunc origChangeDisplaySettingsEx, EnumDisplaySettingsExFunc origEnumDisplaySettingsEx, @@ -121,6 +119,21 @@ namespace lpszDeviceName, lpDevMode, hwnd, dwflags, lParam)); } + void disableDwm8And16BitMitigation() + { + auto user32 = GetModuleHandle("user32"); + Compat::removeShim(user32, "ChangeDisplaySettingsA"); + Compat::removeShim(user32, "ChangeDisplaySettingsW"); + Compat::removeShim(user32, "ChangeDisplaySettingsExA"); + Compat::removeShim(user32, "ChangeDisplaySettingsExW"); + Compat::removeShim(user32, "EnumDisplaySettingsA"); + Compat::removeShim(user32, "EnumDisplaySettingsW"); + Compat::removeShim(user32, "EnumDisplaySettingsExA"); + Compat::removeShim(user32, "EnumDisplaySettingsExW"); + + HOOK_FUNCTION(apphelp, DWM8And16Bit_IsShimApplied_CallOut, dwm8And16BitIsShimAppliedCallOut); + } + BOOL WINAPI dwm8And16BitIsShimAppliedCallOut() { return FALSE; @@ -261,11 +274,6 @@ namespace Win32 return ddQueryDisplaySettingsUniqueness(); } - void disableDwm8And16BitMitigation() - { - HOOK_FUNCTION(apphelp, DWM8And16Bit_IsShimApplied_CallOut, dwm8And16BitIsShimAppliedCallOut); - } - void installHooks() { DEVMODEA devMode = {}; @@ -286,6 +294,8 @@ namespace Win32 HOOK_FUNCTION(user32, EnumDisplaySettingsExA, enumDisplaySettingsExA); HOOK_FUNCTION(user32, EnumDisplaySettingsExW, enumDisplaySettingsExW); HOOK_FUNCTION(gdi32, GetDeviceCaps, getDeviceCaps); + + disableDwm8And16BitMitigation(); } } } diff --git a/DDrawCompat/Win32/DisplayMode.h b/DDrawCompat/Win32/DisplayMode.h index 5da13b1..a38972c 100644 --- a/DDrawCompat/Win32/DisplayMode.h +++ b/DDrawCompat/Win32/DisplayMode.h @@ -9,7 +9,6 @@ namespace Win32 DWORD getBpp(); ULONG queryDisplaySettingsUniqueness(); - void disableDwm8And16BitMitigation(); void installHooks(); } }