diff --git a/DDrawCompat/Common/Log.h b/DDrawCompat/Common/Log.h index 26e3103..f5d8b9a 100644 --- a/DDrawCompat/Common/Log.h +++ b/DDrawCompat/Common/Log.h @@ -290,7 +290,7 @@ namespace Compat auto& param = std::get(std::tie(params...)); if constexpr (IsString::value) { - if constexpr (!std::is_class_v) + if constexpr (std::is_pointer_v>) { if (reinterpret_cast(param) <= 0xFFFF) { diff --git a/DDrawCompat/DDraw/Hooks.cpp b/DDrawCompat/DDraw/Hooks.cpp index 8907645..79659fd 100644 --- a/DDrawCompat/DDraw/Hooks.cpp +++ b/DDrawCompat/DDraw/Hooks.cpp @@ -8,7 +8,6 @@ #include #include #include -#include namespace { @@ -100,11 +99,6 @@ namespace DDraw { RealPrimarySurface::init(); - Win32::Registry::unsetValue( - HKEY_LOCAL_MACHINE, "SOFTWARE\\Microsoft\\DirectDraw", "EmulationOnly"); - Win32::Registry::unsetValue( - HKEY_LOCAL_MACHINE, "SOFTWARE\\WOW6432Node\\Microsoft\\DirectDraw", "EmulationOnly"); - g_origInitialize = dd7.get()->lpVtbl->Initialize; Compat::hookFunction(reinterpret_cast(g_origInitialize), initialize, "IDirectDrawVtbl::Initialize"); diff --git a/DDrawCompat/Win32/Registry.cpp b/DDrawCompat/Win32/Registry.cpp index d674578..c0b6d14 100644 --- a/DDrawCompat/Win32/Registry.cpp +++ b/DDrawCompat/Win32/Registry.cpp @@ -1,8 +1,8 @@ -#include -#include #include -#include -#include +#include +#include + +#include #include #include @@ -10,42 +10,109 @@ typedef long NTSTATUS; +namespace Compat +{ + Log& operator<<(Log& os, HKEY hkey); +} + namespace { - struct RegistryKey + struct RegValue { - HKEY key; - std::wstring subKey; - std::wstring value; - - RegistryKey(HKEY key, const std::wstring& subKey, const std::wstring& value) - : key(key), subKey(subKey), value(value) - { - } - - bool operator<(const RegistryKey& rhs) const - { - if (key < rhs.key) { return true; } - if (key > rhs.key) { return false; } - const int subKeyComp = lstrcmpiW(subKey.c_str(), rhs.subKey.c_str()); - if (subKeyComp < 0) { return true; } - if (subKeyComp > 0) { return false; } - return lstrcmpiW(value.c_str(), rhs.value.c_str()) < 0; - } - - bool operator==(const RegistryKey& rhs) const - { - return key == rhs.key && - 0 == lstrcmpiW(subKey.c_str(), rhs.subKey.c_str()) && - 0 == lstrcmpiW(value.c_str(), rhs.value.c_str()); - } + DWORD type; + union { + const wchar_t* str; + }; }; - std::map g_dwordValues; - std::set g_unsetValues; + struct RegSz : RegValue + { + RegSz(const wchar_t* value) : RegValue{ REG_SZ, value } {} + }; + + struct RegEntry + { + const wchar_t* keyName; + const wchar_t* valueName; + RegValue value; + }; + + template + const char* g_funcName = nullptr; + + std::map g_openKeys; + + const std::map g_predefinedKeys = { +#define PREDEFINED_KEY_NAME_PAIR(key) { key, L#key } + PREDEFINED_KEY_NAME_PAIR(HKEY_CLASSES_ROOT), + PREDEFINED_KEY_NAME_PAIR(HKEY_CURRENT_CONFIG), + PREDEFINED_KEY_NAME_PAIR(HKEY_CURRENT_USER), + PREDEFINED_KEY_NAME_PAIR(HKEY_CURRENT_USER_LOCAL_SETTINGS), + PREDEFINED_KEY_NAME_PAIR(HKEY_DYN_DATA), + PREDEFINED_KEY_NAME_PAIR(HKEY_LOCAL_MACHINE), + PREDEFINED_KEY_NAME_PAIR(HKEY_PERFORMANCE_DATA), + PREDEFINED_KEY_NAME_PAIR(HKEY_PERFORMANCE_NLSTEXT), + PREDEFINED_KEY_NAME_PAIR(HKEY_PERFORMANCE_TEXT), + PREDEFINED_KEY_NAME_PAIR(HKEY_USERS) +#undef PREDEFINED_KEY_NAME_PAIR + }; + + const std::vector g_regEntries = { + { L"HKEY_LOCAL_MACHINE\\Software\\Microsoft\\DirectDraw", L"EmulationOnly", {} }, + { L"HKEY_LOCAL_MACHINE\\Software\\Microsoft\\Windows NT\\CurrentVersion\\DRIVERS32", L"vidc.iv31", RegSz(L"ir32_32.dll") }, + { L"HKEY_LOCAL_MACHINE\\Software\\Microsoft\\Windows NT\\CurrentVersion\\DRIVERS32", L"vidc.iv41", RegSz(L"ir41_32.ax") }, + { L"HKEY_LOCAL_MACHINE\\Software\\Microsoft\\Windows NT\\CurrentVersion\\DRIVERS32", L"vidc.iv50", RegSz(L"ir50_32.dll") }, + }; + +#undef HKLM_SOFTWARE_KEY + + bool filterType(DWORD type, const DWORD* flags) + { + if (!flags) + { + return true; + } + + switch (type) + { + case REG_SZ: + return *flags & RRF_RT_REG_SZ; + } + + return false; + } + + template + HKEY* getHKeyPtr(HKEY* hkey, Params...) + { + return hkey; + } + + template + HKEY* getHKeyPtr(FirstParam, Params... params) + { + return getHKeyPtr(params...); + } std::wstring getKeyName(HKEY key) { + if (!key) + { + return {}; + } + + auto it = g_predefinedKeys.find(key); + if (it != g_predefinedKeys.end()) + { + return it->second; + } + + it = g_openKeys.find(key); + if (it != g_openKeys.end()) + { + return it->second; + } + enum KEY_INFORMATION_CLASS { KeyBasicInformation = 0, @@ -59,7 +126,7 @@ namespace MaxKeyInfoClass = 8 }; - typedef NTSTATUS(WINAPI *NtQueryKeyFuncPtr)( + typedef NTSTATUS(WINAPI* NtQueryKeyFuncPtr)( HANDLE KeyHandle, KEY_INFORMATION_CLASS KeyInformationClass, PVOID KeyInformation, @@ -68,8 +135,8 @@ namespace static NtQueryKeyFuncPtr ntQueryKey = reinterpret_cast( GetProcAddress(GetModuleHandle("ntdll"), "NtQueryKey")); - - if (ntQueryKey) + + if (ntQueryKey && key) { struct KEY_NAME_INFORMATION { @@ -84,110 +151,279 @@ namespace return std::wstring(keyName.Name, keyName.NameLength / 2); } } - - return std::wstring(); + + return {}; } - LONG WINAPI regGetValueW(HKEY hkey, LPCWSTR lpSubKey, LPCWSTR lpValue, - DWORD dwFlags, LPDWORD pdwType, PVOID pvData, LPDWORD pcbData) + std::size_t getLength(const char* str) { - LOG_FUNC("regGetValueW", hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData); - LONG result = ERROR_SUCCESS; + return strlen(str); + } - const auto it = hkey && lpSubKey && lpValue && (dwFlags & RRF_RT_REG_DWORD) - ? g_dwordValues.find(RegistryKey(hkey, lpSubKey, lpValue)) - : g_dwordValues.end(); + std::size_t getLength(const wchar_t* str) + { + return wcslen(str); + } - if (it != g_dwordValues.end()) + const RegValue* getValue(const std::wstring& keyName, const wchar_t* valueName) + { + if (!valueName) { - if (pdwType) + valueName = L""; + } + for (const auto& regEntry : g_regEntries) + { + if (0 == lstrcmpiW(valueName, regEntry.valueName) && + 0 == lstrcmpiW(keyName.c_str(), regEntry.keyName)) { - *pdwType = REG_DWORD; + return ®Entry.value; + } + } + return nullptr; + } + + const RegValue* getValue(const std::wstring& keyName, const char* valueName) + { + if (!valueName) + { + return getValue(keyName, L""); + } + std::wstring convertedValueName(valueName, valueName + strlen(valueName)); + return getValue(keyName, convertedValueName.c_str()); + } + + template + LONG getValue(HKEY hkey, const Char* subKeyName, const Char* valueName, + const DWORD* flags, DWORD* type, void* data, DWORD* length) + { + if (data && !length) + { + return -1; + } + + auto keyName(getKeyName(hkey)); + if (keyName.empty()) + { + return -1; + } + + if (subKeyName) + { + keyName += L'\\'; + keyName.append(subKeyName, subKeyName + getLength(subKeyName)); + } + + const RegValue* value = getValue(keyName, valueName); + if (!value) + { + return -1; + } + + if (REG_NONE == value->type) + { + return ERROR_FILE_NOT_FOUND; + } + + if (!filterType(value->type, flags)) + { + return -1; + } + + if (type) + { + *type = value->type; + } + + if (!length) + { + return ERROR_SUCCESS; + } + + const void* src = nullptr; + const DWORD maxLength = *length; + + switch (value->type) + { + case REG_SZ: + src = value->str; + *length = (getLength(value->str) + (flags ? 1 : 0)) * sizeof(Char); + break; + + default: + *length = 0; + break; + } + + if (data) + { + if (*length > maxLength) + { + if (flags && (*flags & RRF_ZEROONFAILURE)) + { + memset(data, 0, *length); + } + return ERROR_MORE_DATA; } - if (pvData) - { - if (!pcbData) - { - result = ERROR_INVALID_PARAMETER; - } - else if (*pcbData >= sizeof(DWORD)) - { - std::memcpy(pvData, &it->second, sizeof(DWORD)); - } - else - { - result = ERROR_MORE_DATA; - } - } + memcpy(data, src, *length); + } - if (pcbData) - { - *pcbData = sizeof(DWORD); - } + return ERROR_SUCCESS; + } + + void hookRegistryFunction(void*& origFuncPtr, void* newFuncPtr, const char* funcName) + { + auto kernelBase = LoadLibrary("kernelbase"); + if (kernelBase && Compat::getProcAddress(kernelBase, funcName)) + { + Compat::hookFunction(kernelBase, funcName, origFuncPtr, newFuncPtr); } else { - result = CALL_ORIG_FUNC(RegGetValueW)(hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData); + Compat::hookFunction("advapi32", funcName, origFuncPtr, newFuncPtr); } + } + LSTATUS WINAPI regCloseKey(HKEY hKey) + { + LOG_FUNC("RegCloseKey", hKey); + const auto result = CALL_ORIG_FUNC(RegCloseKey)(hKey); + if (ERROR_SUCCESS == result) + { + g_openKeys.erase(hKey); + } return LOG_RESULT(result); } + template + LONG regGetValue(HKEY hkey, const Char* lpSubKey, const Char* lpValue, + DWORD dwFlags, LPDWORD pdwType, PVOID pvData, LPDWORD pcbData, OrigFunc origFunc, const char* funcName) + { + LOG_FUNC(funcName, hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData); + const auto result = getValue(hkey, lpSubKey, lpValue, &dwFlags, pdwType, pvData, pcbData); + if (-1 != result) + { + return LOG_RESULT(result); + } + return LOG_RESULT(origFunc(hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData)); + } + + LONG WINAPI regGetValueA(HKEY hkey, LPCSTR lpSubKey, LPCSTR lpValue, + DWORD dwFlags, LPDWORD pdwType, PVOID pvData, LPDWORD pcbData) + { + return regGetValue(hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData, + CALL_ORIG_FUNC(RegGetValueA), "RegGetValueA"); + } + + LONG WINAPI regGetValueW(HKEY hkey, LPCWSTR lpSubKey, LPCWSTR lpValue, + DWORD dwFlags, LPDWORD pdwType, PVOID pvData, LPDWORD pcbData) + { + return regGetValue(hkey, lpSubKey, lpValue, dwFlags, pdwType, pvData, pcbData, + CALL_ORIG_FUNC(RegGetValueW), "RegGetValueW"); + } + + template + LSTATUS WINAPI regOpenKey(HKEY hKey, const Char* lpSubKey, Params... params) + { + LOG_FUNC(g_funcName, hKey, lpSubKey, params...); + const auto result = Compat::g_origFuncPtr(hKey, lpSubKey, params...); + if (ERROR_SUCCESS == result) + { + const auto hkeyPtr = getHKeyPtr(params...); + if (hkeyPtr) + { + auto keyName(getKeyName(hKey)); + if (lpSubKey) + { + keyName += L'\\'; + keyName.append(lpSubKey, lpSubKey + getLength(lpSubKey)); + } + g_openKeys[*hkeyPtr] = keyName; + } + } + return LOG_RESULT(result); + } + + template + LONG regQueryValueEx(HKEY hKey, const Char* lpValueName, LPDWORD lpReserved, LPDWORD lpType, + LPBYTE lpData, LPDWORD lpcbData, OrigFunc origFunc, const char* funcName) + { + LOG_FUNC(funcName, hKey, lpValueName, lpReserved, lpType, static_cast(lpData), lpcbData); + const auto result = getValue(hKey, nullptr, lpValueName, nullptr, lpType, lpData, lpcbData); + if (-1 != result) + { + return LOG_RESULT(result); + } + return LOG_RESULT(origFunc(hKey, lpValueName, lpReserved, lpType, lpData, lpcbData)); + } + LONG WINAPI regQueryValueExA(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData, LPDWORD lpcbData) { - LOG_FUNC("regQueryValueExA", hKey, lpValueName, lpReserved, lpType, static_cast(lpData), lpcbData); + return regQueryValueEx(hKey, lpValueName, lpReserved, lpType, lpData, lpcbData, + CALL_ORIG_FUNC(RegQueryValueExA), "RegQueryValueExA"); + } - if (hKey && lpValueName) - { - const std::wstring keyName = getKeyName(hKey); - const std::wstring localMachinePrefix = L"\\REGISTRY\\MACHINE\\"; - if (localMachinePrefix == keyName.substr(0, localMachinePrefix.size())) - { - std::wostringstream oss; - oss << lpValueName; - auto it = g_unsetValues.find(RegistryKey(HKEY_LOCAL_MACHINE, - keyName.substr(localMachinePrefix.size()), oss.str())); - if (it != g_unsetValues.end()) - { - return LOG_RESULT(ERROR_FILE_NOT_FOUND); - } - } - } - - return LOG_RESULT(CALL_ORIG_FUNC(RegQueryValueExA)(hKey, lpValueName, lpReserved, lpType, lpData, lpcbData)); + LONG WINAPI regQueryValueExW(HKEY hKey, LPCWSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, + LPBYTE lpData, LPDWORD lpcbData) + { + return regQueryValueEx(hKey, lpValueName, lpReserved, lpType, lpData, lpcbData, + CALL_ORIG_FUNC(RegQueryValueExW), "RegQueryValueExW"); } } +namespace Compat +{ + Log& operator<<(Log& os, HKEY hkey) + { + auto it = g_predefinedKeys.find(hkey); + if (it != g_predefinedKeys.end()) + { + return os << it->second.c_str(); + } + + os << "HKEY(" << static_cast(hkey); + auto keyName(getKeyName(hkey)); + if (!keyName.empty()) + { + os << ',' << '"' << keyName.c_str() << '"'; + } + return os << ')'; + } +} + +#define HOOK_REGISTRY_FUNCTION(func, newFunc) \ + hookRegistryFunction(reinterpret_cast(Compat::g_origFuncPtr<&func>), static_cast(newFunc), #func) + +#define HOOK_REGISTRY_OPEN_FUNCTION(func) \ + g_funcName = #func; \ + HOOK_REGISTRY_FUNCTION(func, regOpenKey) + namespace Win32 { namespace Registry { void installHooks() { - HOOK_SHIM_FUNCTION(RegGetValueW, regGetValueW); - HOOK_SHIM_FUNCTION(RegQueryValueExA, regQueryValueExA); - } + HOOK_REGISTRY_OPEN_FUNCTION(RegCreateKeyA); + HOOK_REGISTRY_OPEN_FUNCTION(RegCreateKeyW); + HOOK_REGISTRY_OPEN_FUNCTION(RegCreateKeyExA); + HOOK_REGISTRY_OPEN_FUNCTION(RegCreateKeyExW); + HOOK_REGISTRY_OPEN_FUNCTION(RegCreateKeyTransactedA); + HOOK_REGISTRY_OPEN_FUNCTION(RegCreateKeyTransactedW); - void setValue(HKEY key, const char* subKey, const char* valueName, DWORD value) - { - assert(key && subKey && valueName); - std::wostringstream subKeyW; - subKeyW << subKey; - std::wostringstream valueNameW; - valueNameW << valueName; - g_dwordValues[RegistryKey(key, subKeyW.str(), valueNameW.str())] = value; - } + HOOK_REGISTRY_OPEN_FUNCTION(RegOpenKeyA); + HOOK_REGISTRY_OPEN_FUNCTION(RegOpenKeyW); + HOOK_REGISTRY_OPEN_FUNCTION(RegOpenKeyExA); + HOOK_REGISTRY_OPEN_FUNCTION(RegOpenKeyExW); + HOOK_REGISTRY_OPEN_FUNCTION(RegOpenKeyTransactedA); + HOOK_REGISTRY_OPEN_FUNCTION(RegOpenKeyTransactedW); - void unsetValue(HKEY key, const char* subKey, const char* valueName) - { - assert(key && subKey && valueName); - std::wostringstream subKeyW; - subKeyW << subKey; - std::wostringstream valueNameW; - valueNameW << valueName; - g_unsetValues.insert(RegistryKey(key, subKeyW.str(), valueNameW.str())); + HOOK_REGISTRY_FUNCTION(RegCloseKey, regCloseKey); + HOOK_REGISTRY_FUNCTION(RegGetValueA, regGetValueA); + HOOK_REGISTRY_FUNCTION(RegGetValueW, regGetValueW); + HOOK_REGISTRY_FUNCTION(RegQueryValueExA, regQueryValueExA); + HOOK_REGISTRY_FUNCTION(RegQueryValueExW, regQueryValueExW); } } } diff --git a/DDrawCompat/Win32/Registry.h b/DDrawCompat/Win32/Registry.h index da1aba2..293bb5f 100644 --- a/DDrawCompat/Win32/Registry.h +++ b/DDrawCompat/Win32/Registry.h @@ -1,13 +1,9 @@ #pragma once -#include - namespace Win32 { namespace Registry { void installHooks(); - void setValue(HKEY key, const char* subKey, const char* valueName, DWORD value); - void unsetValue(HKEY key, const char* subKey, const char* valueName); } }