From f7f5348a870d20e6abba060f348dbc691098f787 Mon Sep 17 00:00:00 2001 From: narzoul Date: Mon, 2 May 2016 22:16:40 +0200 Subject: [PATCH] Fixed/simplified hooking logic for COM methods --- DDrawCompat/CompatDirectDrawSurface.cpp | 91 ++++++++----------------- DDrawCompat/CompatDirectDrawSurface.h | 4 -- DDrawCompat/CompatPrimarySurface.cpp | 44 ++++++++++-- DDrawCompat/CompatPrimarySurface.h | 3 + DDrawCompat/CompatVtable.cpp | 6 -- DDrawCompat/CompatVtable.h | 53 +------------- DDrawCompat/DDrawCompat.vcxproj | 1 - DDrawCompat/DDrawCompat.vcxproj.filters | 3 - DDrawCompat/Hook.cpp | 25 +++++-- 9 files changed, 91 insertions(+), 139 deletions(-) delete mode 100644 DDrawCompat/CompatVtable.cpp diff --git a/DDrawCompat/CompatDirectDrawSurface.cpp b/DDrawCompat/CompatDirectDrawSurface.cpp index 0bc36a5..5e3ece2 100644 --- a/DDrawCompat/CompatDirectDrawSurface.cpp +++ b/DDrawCompat/CompatDirectDrawSurface.cpp @@ -236,8 +236,12 @@ HRESULT CompatDirectDrawSurface::createCompatPrimarySurface( return result; } - s_compatPrimarySurface = compatSurface; - initCompatPrimarySurface(); + IDirectDrawSurface7* compatSurface7 = nullptr; + s_origVtable.QueryInterface(compatSurface, IID_IDirectDrawSurface7, + reinterpret_cast(&compatSurface7)); + CompatPrimarySurface::setPrimary(compatSurface7); + CompatDirectDrawSurface::s_origVtable.Release(compatSurface7); + return DD_OK; } @@ -250,22 +254,6 @@ void CompatDirectDrawSurface::fixSurfacePtrs(TSurface& surface) surface7->lpVtbl->Release(surface7); } -template -void CompatDirectDrawSurface::initPrimarySurfacePtr(const GUID& guid, IUnknown& surface) -{ - if (SUCCEEDED(surface.lpVtbl->QueryInterface( - &surface, guid, reinterpret_cast(&s_compatPrimarySurface)))) - { - s_compatPrimarySurface->lpVtbl->Release(s_compatPrimarySurface); - } -} - -template -void CompatDirectDrawSurface::resetPrimarySurfacePtr() -{ - s_compatPrimarySurface = nullptr; -} - template HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Blt( TSurface* This, @@ -275,7 +263,8 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Blt( DWORD dwFlags, LPDDBLTFX lpDDBltFx) { - if ((This == s_compatPrimarySurface || lpDDSrcSurface == s_compatPrimarySurface) && + const bool isPrimaryDest = CompatPrimarySurface::isPrimary(This); + if ((isPrimaryDest || CompatPrimarySurface::isPrimary(lpDDSrcSurface)) && RealPrimarySurface::isLost()) { return DDERR_SURFACELOST; @@ -327,7 +316,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Blt( result = s_origVtable.Blt(This, lpDestRect, lpDDSrcSurface, lpSrcRect, dwFlags, lpDDBltFx); } - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (isPrimaryDest && SUCCEEDED(result)) { RealPrimarySurface::invalidate(lpDestRect); RealPrimarySurface::update(); @@ -345,14 +334,15 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::BltFast( LPRECT lpSrcRect, DWORD dwTrans) { - if ((This == s_compatPrimarySurface || lpDDSrcSurface == s_compatPrimarySurface) && + const bool isPrimaryDest = CompatPrimarySurface::isPrimary(This); + if ((isPrimaryDest || CompatPrimarySurface::isPrimary(lpDDSrcSurface)) && RealPrimarySurface::isLost()) { return DDERR_SURFACELOST; } HRESULT result = s_origVtable.BltFast(This, dwX, dwY, lpDDSrcSurface, lpSrcRect, dwTrans); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (isPrimaryDest && SUCCEEDED(result)) { const LONG x = dwX; const LONG y = dwY; @@ -383,7 +373,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Flip( DWORD dwFlags) { HRESULT result = s_origVtable.Flip(This, lpDDSurfaceTargetOverride, dwFlags); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (SUCCEEDED(result) && CompatPrimarySurface::isPrimary(This)) { result = RealPrimarySurface::flip(dwFlags); } @@ -396,7 +386,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::GetCaps( TDdsCaps* lpDDSCaps) { HRESULT result = s_origVtable.GetCaps(This, lpDDSCaps); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (SUCCEEDED(result) && CompatPrimarySurface::isPrimary(This)) { restorePrimaryCaps(*lpDDSCaps); } @@ -409,7 +399,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::GetSurfaceDesc( TSurfaceDesc* lpDDSurfaceDesc) { HRESULT result = s_origVtable.GetSurfaceDesc(This, lpDDSurfaceDesc); - if (This == s_compatPrimarySurface && SUCCEEDED(result) && !g_lockingPrimary) + if (SUCCEEDED(result) && !g_lockingPrimary && CompatPrimarySurface::isPrimary(This)) { restorePrimaryCaps(lpDDSurfaceDesc->ddsCaps); } @@ -420,7 +410,7 @@ template HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::IsLost(TSurface* This) { HRESULT result = s_origVtable.IsLost(This); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (SUCCEEDED(result) && CompatPrimarySurface::isPrimary(This)) { result = RealPrimarySurface::isLost() ? DDERR_SURFACELOST : DD_OK; } @@ -435,7 +425,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Lock( DWORD dwFlags, HANDLE hEvent) { - if (This == s_compatPrimarySurface) + if (CompatPrimarySurface::isPrimary(This)) { if (RealPrimarySurface::isLost()) { @@ -472,7 +462,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::QueryInterface( REFIID riid, LPVOID* obp) { - if (This == s_compatPrimarySurface && riid == IID_IDirectDrawGammaControl) + if (riid == IID_IDirectDrawGammaControl && CompatPrimarySurface::isPrimary(This)) { return RealPrimarySurface::getSurface()->lpVtbl->QueryInterface( RealPrimarySurface::getSurface(), riid, obp); @@ -483,13 +473,14 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::QueryInterface( template HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::ReleaseDC(TSurface* This, HDC hDC) { - if (This == s_compatPrimarySurface && RealPrimarySurface::isLost()) + const bool isPrimary = CompatPrimarySurface::isPrimary(This); + if (isPrimary && RealPrimarySurface::isLost()) { return DDERR_SURFACELOST; } HRESULT result = s_origVtable.ReleaseDC(This, hDC); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (isPrimary && SUCCEEDED(result)) { RealPrimarySurface::invalidate(nullptr); RealPrimarySurface::update(); @@ -508,7 +499,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Restore(TSurface* T { fixSurfacePtrs(*This); } - if (This == s_compatPrimarySurface) + if (CompatPrimarySurface::isPrimary(This)) { result = RealPrimarySurface::restore(); if (wasLost) @@ -526,7 +517,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::SetClipper( LPDIRECTDRAWCLIPPER lpDDClipper) { HRESULT result = s_origVtable.SetClipper(This, lpDDClipper); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (SUCCEEDED(result) && CompatPrimarySurface::isPrimary(This)) { RealPrimarySurface::setClipper(lpDDClipper); } @@ -538,7 +529,8 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::SetPalette( TSurface* This, LPDIRECTDRAWPALETTE lpDDPalette) { - if (This == s_compatPrimarySurface) + const bool isPrimary = CompatPrimarySurface::isPrimary(This); + if (isPrimary) { if (lpDDPalette) { @@ -551,7 +543,7 @@ HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::SetPalette( } HRESULT result = s_origVtable.SetPalette(This, lpDDPalette); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (isPrimary && SUCCEEDED(result)) { CompatPrimarySurface::palette = lpDDPalette; RealPrimarySurface::setPalette(); @@ -563,39 +555,13 @@ template HRESULT STDMETHODCALLTYPE CompatDirectDrawSurface::Unlock(TSurface* This, TUnlockParam lpRect) { HRESULT result = s_origVtable.Unlock(This, lpRect); - if (This == s_compatPrimarySurface && SUCCEEDED(result)) + if (SUCCEEDED(result) && CompatPrimarySurface::isPrimary(This)) { RealPrimarySurface::update(); } return result; } -template -void CompatDirectDrawSurface::initCompatPrimarySurface() -{ - Compat::LogEnter("CompatDirectDrawSurface::initCompatPrimarySurface"); - - IUnknown& unk = reinterpret_cast(*s_compatPrimarySurface); - CompatDirectDrawSurface::initPrimarySurfacePtr(IID_IDirectDrawSurface, unk); - CompatDirectDrawSurface::initPrimarySurfacePtr(IID_IDirectDrawSurface2, unk); - CompatDirectDrawSurface::initPrimarySurfacePtr(IID_IDirectDrawSurface3, unk); - CompatDirectDrawSurface::initPrimarySurfacePtr(IID_IDirectDrawSurface4, unk); - CompatDirectDrawSurface::initPrimarySurfacePtr(IID_IDirectDrawSurface7, unk); - - if (SUCCEEDED(s_origVtable.QueryInterface( - s_compatPrimarySurface, - IID_IDirectDrawSurface7, - reinterpret_cast(&CompatPrimarySurface::surface)))) - { - IReleaseNotifier* releaseNotifier = &CompatPrimarySurface::releaseNotifier; - CompatPrimarySurface::surface->lpVtbl->SetPrivateData(CompatPrimarySurface::surface, - IID_IReleaseNotifier, releaseNotifier, sizeof(releaseNotifier), DDSPD_IUNKNOWNPOINTER); - CompatPrimarySurface::surface->lpVtbl->Release(CompatPrimarySurface::surface); - } - - Compat::LogLeave("CompatDirectDrawSurface::initCompatPrimarySurface"); -} - template void CompatDirectDrawSurface::restorePrimaryCaps(TDdsCaps& caps) { @@ -603,9 +569,6 @@ void CompatDirectDrawSurface::restorePrimaryCaps(TDdsCaps& caps) caps.dwCaps |= DDSCAPS_PRIMARYSURFACE | DDSCAPS_VISIBLE; } -template -TSurface* CompatDirectDrawSurface::s_compatPrimarySurface = nullptr; - template <> const IID& CompatDirectDrawSurface::s_iid = IID_IDirectDrawSurface; template <> const IID& CompatDirectDrawSurface::s_iid = IID_IDirectDrawSurface2; template <> const IID& CompatDirectDrawSurface::s_iid = IID_IDirectDrawSurface3; diff --git a/DDrawCompat/CompatDirectDrawSurface.h b/DDrawCompat/CompatDirectDrawSurface.h index e3fe4ed..7a1e1dc 100644 --- a/DDrawCompat/CompatDirectDrawSurface.h +++ b/DDrawCompat/CompatDirectDrawSurface.h @@ -21,8 +21,6 @@ public: TSurface*& compatSurface); static void fixSurfacePtrs(TSurface& surface); - static void initPrimarySurfacePtr(const GUID& guid, IUnknown& surface); - static void resetPrimarySurfacePtr(); static HRESULT STDMETHODCALLTYPE Blt( TSurface* This, @@ -66,7 +64,5 @@ public: static const IID& s_iid; private: - static void initCompatPrimarySurface(); static void restorePrimaryCaps(TDdsCaps& caps); - static TSurface* s_compatPrimarySurface; }; diff --git a/DDrawCompat/CompatPrimarySurface.cpp b/DDrawCompat/CompatPrimarySurface.cpp index 2e35a97..652138e 100644 --- a/DDrawCompat/CompatPrimarySurface.cpp +++ b/DDrawCompat/CompatPrimarySurface.cpp @@ -1,3 +1,6 @@ +#include +#include + #include "CompatDirectDraw.h" #include "CompatDirectDrawSurface.h" #include "CompatPrimarySurface.h" @@ -6,10 +9,22 @@ namespace { + std::vector g_primarySurfacePtrs; + + void addPrimary(IDirectDrawSurface7* surface, const IID& iid) + { + IUnknown* intf = nullptr; + CompatDirectDrawSurface::s_origVtable.QueryInterface( + surface, iid, reinterpret_cast(&intf)); + g_primarySurfacePtrs.push_back(intf); + intf->lpVtbl->Release(intf); + } + void onRelease() { Compat::LogEnter("CompatPrimarySurface::onRelease"); + g_primarySurfacePtrs.clear(); CompatPrimarySurface::surface = nullptr; CompatPrimarySurface::palette = nullptr; CompatPrimarySurface::width = 0; @@ -17,12 +32,6 @@ namespace ZeroMemory(&CompatPrimarySurface::paletteEntries, sizeof(CompatPrimarySurface::paletteEntries)); ZeroMemory(&CompatPrimarySurface::pixelFormat, sizeof(CompatPrimarySurface::pixelFormat)); - CompatDirectDrawSurface::resetPrimarySurfacePtr(); - CompatDirectDrawSurface::resetPrimarySurfacePtr(); - CompatDirectDrawSurface::resetPrimarySurfacePtr(); - CompatDirectDrawSurface::resetPrimarySurfacePtr(); - CompatDirectDrawSurface::resetPrimarySurfacePtr(); - RealPrimarySurface::release(); Compat::LogLeave("CompatPrimarySurface::onRelease"); @@ -50,6 +59,29 @@ namespace CompatPrimarySurface template DisplayMode getDisplayMode(IDirectDraw4& dd); template DisplayMode getDisplayMode(IDirectDraw7& dd); + bool isPrimary(void* surfacePtr) + { + return g_primarySurfacePtrs.end() != + std::find(g_primarySurfacePtrs.begin(), g_primarySurfacePtrs.end(), surfacePtr); + } + + void setPrimary(IDirectDrawSurface7* surfacePtr) + { + surface = surfacePtr; + + g_primarySurfacePtrs.clear(); + g_primarySurfacePtrs.push_back(surfacePtr); + addPrimary(surfacePtr, IID_IDirectDrawSurface4); + addPrimary(surfacePtr, IID_IDirectDrawSurface3); + addPrimary(surfacePtr, IID_IDirectDrawSurface2); + addPrimary(surfacePtr, IID_IDirectDrawSurface); + + IReleaseNotifier* releaseNotifierPtr = &releaseNotifier; + CompatDirectDrawSurface::s_origVtable.SetPrivateData( + surfacePtr, IID_IReleaseNotifier, releaseNotifierPtr, sizeof(releaseNotifierPtr), + DDSPD_IUNKNOWNPOINTER); + } + DisplayMode displayMode = {}; bool isDisplayModeChanged = false; IDirectDrawSurface7* surface = nullptr; diff --git a/DDrawCompat/CompatPrimarySurface.h b/DDrawCompat/CompatPrimarySurface.h index 265c043..001ba60 100644 --- a/DDrawCompat/CompatPrimarySurface.h +++ b/DDrawCompat/CompatPrimarySurface.h @@ -19,6 +19,9 @@ namespace CompatPrimarySurface template DisplayMode getDisplayMode(TDirectDraw& dd); + bool isPrimary(void* surfacePtr); + void setPrimary(IDirectDrawSurface7* surfacePtr); + extern DisplayMode displayMode; extern bool isDisplayModeChanged; extern IDirectDrawSurface7* surface; diff --git a/DDrawCompat/CompatVtable.cpp b/DDrawCompat/CompatVtable.cpp deleted file mode 100644 index c70ee3d..0000000 --- a/DDrawCompat/CompatVtable.cpp +++ /dev/null @@ -1,6 +0,0 @@ -#include "CompatVtable.h" - -namespace Compat -{ - std::map g_hookedMethods; -} diff --git a/DDrawCompat/CompatVtable.h b/DDrawCompat/CompatVtable.h index db7c32f..9489eac 100644 --- a/DDrawCompat/CompatVtable.h +++ b/DDrawCompat/CompatVtable.h @@ -11,22 +11,6 @@ template using Vtable = typename std::remove_pointer::type; -namespace Compat -{ - struct HookedMethodInfo - { - HookedMethodInfo(void*& updatedOrigMethodPtr, std::map& vtablePtrToCompatVtable) - : updatedOrigMethodPtr(updatedOrigMethodPtr), vtablePtrToCompatVtable(vtablePtrToCompatVtable) - { - } - - void*& updatedOrigMethodPtr; - std::map& vtablePtrToCompatVtable; - }; - - extern std::map g_hookedMethods; -} - template class CompatVtable { @@ -65,7 +49,7 @@ private: else { s_threadSafeVtable.*ptr = getThreadSafeFuncPtr(s_compatVtable.*ptr); - hookMethod(reinterpret_cast(s_origVtable.*ptr), s_threadSafeVtable.*ptr); + Compat::hookFunction(reinterpret_cast(s_origVtable.*ptr), s_threadSafeVtable.*ptr); } } @@ -75,7 +59,7 @@ private: s_funcNames[getKey()] = vtableTypeName + "::" + funcName; s_threadSafeVtable.*ptr = getThreadSafeFuncPtr(s_compatVtable.*ptr); - hookMethod(reinterpret_cast(s_origVtable.*ptr), s_threadSafeVtable.*ptr); + Compat::hookFunction(reinterpret_cast(s_origVtable.*ptr), s_threadSafeVtable.*ptr); if (!(s_compatVtable.*ptr)) { @@ -101,23 +85,6 @@ private: return &threadSafeFunc; } - void hookMethod(void*& origMethodPtr, void* newMethodPtr) - { - auto it = Compat::g_hookedMethods.find(origMethodPtr); - if (it != Compat::g_hookedMethods.end()) - { - origMethodPtr = it->second.updatedOrigMethodPtr; - it->second.vtablePtrToCompatVtable[s_vtablePtr] = &s_compatVtable; - } - else - { - s_vtablePtrToCompatVtable[s_vtablePtr] = &s_compatVtable; - Compat::g_hookedMethods.emplace(origMethodPtr, - Compat::HookedMethodInfo(origMethodPtr, s_vtablePtrToCompatVtable)); - Compat::hookFunction(origMethodPtr, newMethodPtr); - } - } - template static Result STDMETHODCALLTYPE threadSafeFunc(IntfPtr This, Params... params) { @@ -126,17 +93,7 @@ private: Compat::LogEnter(s_funcNames[getKey()].c_str(), This, params...); #endif - Result result; - auto it = s_vtablePtrToCompatVtable.find(This->lpVtbl); - if (it != s_vtablePtrToCompatVtable.end()) - { - Vtable& compatVtable = *static_cast*>(it->second); - result = (compatVtable.*ptr)(This, params...); - } - else - { - result = (s_origVtable.*ptr)(This, params...); - } + Result result = (s_compatVtable.*ptr)(This, params...); #ifdef _DEBUG Compat::LogLeave(s_funcNames[getKey()].c_str(), This, params...) << result; @@ -162,7 +119,6 @@ private: static Vtable* s_vtablePtr; static Vtable s_compatVtable; static Vtable s_threadSafeVtable; - static std::map s_vtablePtrToCompatVtable; static std::map, std::string> s_funcNames; }; @@ -178,8 +134,5 @@ Vtable CompatVtable::s_compatVtable(Compa template Vtable CompatVtable::s_threadSafeVtable = {}; -template -std::map CompatVtable::s_vtablePtrToCompatVtable; - template std::map, std::string> CompatVtable::s_funcNames; diff --git a/DDrawCompat/DDrawCompat.vcxproj b/DDrawCompat/DDrawCompat.vcxproj index 5166c04..92d0739 100644 --- a/DDrawCompat/DDrawCompat.vcxproj +++ b/DDrawCompat/DDrawCompat.vcxproj @@ -194,7 +194,6 @@ - diff --git a/DDrawCompat/DDrawCompat.vcxproj.filters b/DDrawCompat/DDrawCompat.vcxproj.filters index 4a16f61..1578c88 100644 --- a/DDrawCompat/DDrawCompat.vcxproj.filters +++ b/DDrawCompat/DDrawCompat.vcxproj.filters @@ -137,9 +137,6 @@ Source Files - - Source Files - Source Files diff --git a/DDrawCompat/Hook.cpp b/DDrawCompat/Hook.cpp index 22e3484..bf5c371 100644 --- a/DDrawCompat/Hook.cpp +++ b/DDrawCompat/Hook.cpp @@ -1,7 +1,7 @@ #define WIN32_LEAN_AND_MEAN +#include #include -#include #include #include @@ -11,7 +11,13 @@ namespace { - std::vector> g_hookedFunctions; + struct HookedFunctionInfo + { + void* trampoline; + void* newFunction; + }; + + std::map g_hookedFunctions; FARPROC getProcAddress(HMODULE module, const char* procName) { @@ -53,6 +59,15 @@ namespace void hookFunction(const char* funcName, void*& origFuncPtr, void* newFuncPtr) { + const auto it = g_hookedFunctions.find(origFuncPtr); + if (it != g_hookedFunctions.end()) + { + origFuncPtr = it->second.trampoline; + return; + } + + void* const hookedFuncPtr = origFuncPtr; + DetourTransactionBegin(); const bool attachSuccessful = NO_ERROR == DetourAttach(&origFuncPtr, newFuncPtr); const bool commitSuccessful = NO_ERROR == DetourTransactionCommit(); @@ -69,7 +84,7 @@ namespace return; } - g_hookedFunctions.push_back(std::make_pair(origFuncPtr, newFuncPtr)); + g_hookedFunctions[hookedFuncPtr] = { origFuncPtr, newFuncPtr }; } } @@ -79,7 +94,7 @@ namespace Compat { ::hookFunction(nullptr, origFuncPtr, newFuncPtr); } - + void hookFunction(const char* moduleName, const char* funcName, void*& origFuncPtr, void* newFuncPtr) { FARPROC procAddr = getProcAddress(GetModuleHandle(moduleName), funcName); @@ -98,7 +113,7 @@ namespace Compat for (auto& hookedFunc : g_hookedFunctions) { DetourTransactionBegin(); - DetourDetach(&hookedFunc.first, hookedFunc.second); + DetourDetach(&hookedFunc.second.trampoline, hookedFunc.second.newFunction); DetourTransactionCommit(); } }