diff --git a/DDrawCompat/Common/Hook.h b/DDrawCompat/Common/Hook.h index 4480666..4e8dd0d 100644 --- a/DDrawCompat/Common/Hook.h +++ b/DDrawCompat/Common/Hook.h @@ -8,6 +8,10 @@ #define HOOK_FUNCTION(module, func, newFunc) \ Compat::hookFunction(#module, #func, &newFunc) +#define HOOK_SHIM_FUNCTION(func, newFunc) \ + Compat::hookFunction( \ + reinterpret_cast(Compat::getOrigFuncPtr()), newFunc); + namespace Compat { diff --git a/DDrawCompat/DDraw/Hooks.cpp b/DDrawCompat/DDraw/Hooks.cpp index a4a5aa3..1a77c60 100644 --- a/DDrawCompat/DDraw/Hooks.cpp +++ b/DDrawCompat/DDraw/Hooks.cpp @@ -13,6 +13,7 @@ #include "DDraw/RealPrimarySurface.h" #include "DDraw/Repository.h" #include "Dll/Procs.h" +#include "Win32/Registry.h" namespace { @@ -83,6 +84,11 @@ namespace DDraw { void installHooks() { + Win32::Registry::unsetValue( + HKEY_LOCAL_MACHINE, "SOFTWARE\\Microsoft\\DirectDraw", "EmulationOnly"); + Win32::Registry::unsetValue( + HKEY_LOCAL_MACHINE, "SOFTWARE\\WOW6432Node\\Microsoft\\DirectDraw", "EmulationOnly"); + CompatPtr dd; CALL_ORIG_PROC(DirectDrawCreate, nullptr, &dd.getRef(), nullptr); if (!dd) diff --git a/DDrawCompat/Dll/DllMain.cpp b/DDrawCompat/Dll/DllMain.cpp index 673fd5e..00376df 100644 --- a/DDrawCompat/Dll/DllMain.cpp +++ b/DDrawCompat/Dll/DllMain.cpp @@ -30,6 +30,8 @@ namespace static bool isAlreadyInstalled = false; if (!isAlreadyInstalled) { + Compat::Log() << "Installing registry hooks"; + Win32::Registry::installHooks(); Compat::Log() << "Installing Direct3D driver hooks"; D3dDdi::installHooks(); Compat::Log() << "Installing DirectDraw hooks"; @@ -38,8 +40,6 @@ namespace Direct3d::installHooks(); Compat::Log() << "Installing GDI hooks"; Gdi::installHooks(); - Compat::Log() << "Installing registry hooks"; - Win32::Registry::installHooks(); Compat::Log() << "Finished installing hooks"; isAlreadyInstalled = true; } diff --git a/DDrawCompat/Win32/Registry.cpp b/DDrawCompat/Win32/Registry.cpp index 193155d..b16b9af 100644 --- a/DDrawCompat/Win32/Registry.cpp +++ b/DDrawCompat/Win32/Registry.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include @@ -36,19 +37,64 @@ namespace } }; - std::map g_registryOverride; + std::map g_dwordValues; + std::set g_unsetValues; - LSTATUS WINAPI regGetValueW(HKEY hkey, LPCWSTR lpSubKey, LPCWSTR lpValue, + CStringW getKeyName(HKEY key) + { + enum KEY_INFORMATION_CLASS + { + KeyBasicInformation = 0, + KeyNodeInformation = 1, + KeyFullInformation = 2, + KeyNameInformation = 3, + KeyCachedInformation = 4, + KeyFlagsInformation = 5, + KeyVirtualizationInformation = 6, + KeyHandleTagsInformation = 7, + MaxKeyInfoClass = 8 + }; + + typedef NTSTATUS(WINAPI *NtQueryKeyFuncPtr)( + HANDLE KeyHandle, + KEY_INFORMATION_CLASS KeyInformationClass, + PVOID KeyInformation, + ULONG Length, + PULONG ResultLength); + + static NtQueryKeyFuncPtr ntQueryKey = reinterpret_cast( + GetProcAddress(GetModuleHandle("ntdll"), "NtQueryKey")); + + if (ntQueryKey) + { + struct KEY_NAME_INFORMATION + { + ULONG NameLength; + WCHAR Name[256]; + }; + + KEY_NAME_INFORMATION keyName = {}; + ULONG resultSize = 0; + if (SUCCEEDED(ntQueryKey(key, KeyNameInformation, &keyName, sizeof(keyName), &resultSize))) + { + return CStringW(keyName.Name, keyName.NameLength / 2); + } + } + + return CStringW(); + } + + LONG WINAPI regGetValueW(HKEY hkey, LPCWSTR lpSubKey, LPCWSTR lpValue, DWORD dwFlags, LPDWORD pdwType, PVOID pvData, LPDWORD pcbData) { Compat::LogEnter("regGetValueW", hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData); - LSTATUS result = ERROR_SUCCESS; + LONG result = ERROR_SUCCESS; const auto it = hkey && lpSubKey && lpValue && (dwFlags & RRF_RT_REG_DWORD) - ? g_registryOverride.find(RegistryKey(hkey, lpSubKey, lpValue)) - : g_registryOverride.end(); + ? g_dwordValues.find(RegistryKey(hkey, lpSubKey, lpValue)) + : g_dwordValues.end(); - if (it != g_registryOverride.end()) + if (it != g_dwordValues.end()) { if (pdwType) { @@ -84,6 +130,35 @@ namespace Compat::LogLeave("regGetValueW", hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData) << result; return result; } + + LONG WINAPI regQueryValueExA(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, + LPBYTE lpData, LPDWORD lpcbData) + { + Compat::LogEnter("regQueryValueExA", hKey, lpValueName, lpReserved, lpType, + static_cast(lpData), lpcbData); + + if (hKey && lpValueName) + { + const CStringW keyName = getKeyName(hKey); + const CStringW localMachinePrefix = "\\REGISTRY\\MACHINE\\"; + if (localMachinePrefix == keyName.Mid(0, localMachinePrefix.GetLength())) + { + auto it = g_unsetValues.find(RegistryKey(HKEY_LOCAL_MACHINE, + keyName.Mid(localMachinePrefix.GetLength()), lpValueName)); + if (it != g_unsetValues.end()) + { + return ERROR_FILE_NOT_FOUND; + } + } + } + + LONG result = CALL_ORIG_FUNC(RegQueryValueExA)(hKey, lpValueName, lpReserved, lpType, + lpData, lpcbData); + + Compat::LogLeave("regQueryValueExA", hKey, lpValueName, lpReserved, lpType, + static_cast(lpData), lpcbData) << result; + return result; + } } namespace Win32 @@ -93,12 +168,19 @@ namespace Win32 void installHooks() { HOOK_FUNCTION(KernelBase, RegGetValueW, regGetValueW); + HOOK_SHIM_FUNCTION(RegQueryValueExA, regQueryValueExA); } void setValue(HKEY key, const char* subKey, const char* valueName, DWORD value) { assert(key && subKey && valueName); - g_registryOverride[RegistryKey(key, subKey, valueName)] = value; + g_dwordValues[RegistryKey(key, subKey, valueName)] = value; + } + + void unsetValue(HKEY key, const char* subKey, const char* valueName) + { + assert(key && subKey && valueName); + g_unsetValues.insert(RegistryKey(key, subKey, valueName)); } } } diff --git a/DDrawCompat/Win32/Registry.h b/DDrawCompat/Win32/Registry.h index 4f994e3..6705b18 100644 --- a/DDrawCompat/Win32/Registry.h +++ b/DDrawCompat/Win32/Registry.h @@ -10,5 +10,6 @@ namespace Win32 { void installHooks(); void setValue(HKEY key, const char* subKey, const char* valueName, DWORD value); + void unsetValue(HKEY key, const char* subKey, const char* valueName); } }