1
0
mirror of https://github.com/narzoul/DDrawCompat synced 2024-12-30 08:55:36 +01:00

Fixed unsafe use of released primary surface interface

This commit is contained in:
narzoul 2016-05-15 17:59:35 +02:00
parent 2717b095ec
commit 70a29c2f12
14 changed files with 291 additions and 51 deletions

View File

@ -3,6 +3,7 @@
#include "CompatDirectDrawSurface.h" #include "CompatDirectDrawSurface.h"
#include "CompatGdi.h" #include "CompatGdi.h"
#include "CompatPrimarySurface.h" #include "CompatPrimarySurface.h"
#include "CompatPtr.h"
#include "DDrawLog.h" #include "DDrawLog.h"
extern HWND g_mainWindow; extern HWND g_mainWindow;
@ -38,12 +39,10 @@ namespace
&dd, dm.width, dm.height, 32, dm.refreshRate, 0); &dd, dm.width, dm.height, 32, dm.refreshRate, 0);
} }
if (CompatPrimarySurface::surface) auto primary(CompatPrimarySurface::getPrimary());
if (primary && SUCCEEDED(primary->Restore(primary)))
{ {
CompatDirectDrawSurface<IDirectDrawSurface7>::s_origVtable.Restore( CompatDirectDrawSurface<IDirectDrawSurface7>::fixSurfacePtrs(*primary);
CompatPrimarySurface::surface);
CompatDirectDrawSurface<IDirectDrawSurface7>::fixSurfacePtrs(
*CompatPrimarySurface::surface);
CompatGdi::invalidate(nullptr); CompatGdi::invalidate(nullptr);
} }
} }

View File

@ -236,11 +236,8 @@ HRESULT CompatDirectDrawSurface<TSurface>::createCompatPrimarySurface(
return result; return result;
} }
IDirectDrawSurface7* compatSurface7 = nullptr; CompatPtr<IDirectDrawSurface7> primary(Compat::queryInterface<IDirectDrawSurface7>(compatSurface));
s_origVtable.QueryInterface(compatSurface, IID_IDirectDrawSurface7, CompatPrimarySurface::setPrimary(*primary);
reinterpret_cast<void**>(&compatSurface7));
CompatPrimarySurface::setPrimary(compatSurface7);
CompatDirectDrawSurface<IDirectDrawSurface7>::s_origVtable.Release(compatSurface7);
return DD_OK; return DD_OK;
} }

View File

@ -1,5 +1,7 @@
#pragma once #pragma once
#include <type_traits>
#include "CompatVtable.h" #include "CompatVtable.h"
#include "DDrawTypes.h" #include "DDrawTypes.h"
#include "DirectDrawSurfaceVtblVisitor.h" #include "DirectDrawSurfaceVtblVisitor.h"
@ -66,3 +68,40 @@ public:
private: private:
static void restorePrimaryCaps(TDdsCaps& caps); static void restorePrimaryCaps(TDdsCaps& caps);
}; };
namespace Compat
{
template <typename Intf>
struct IsDirectDrawSurfaceIntf : std::false_type {};
template<> struct IsDirectDrawSurfaceIntf<IDirectDrawSurface> : std::true_type {};
template<> struct IsDirectDrawSurfaceIntf<IDirectDrawSurface2> : std::true_type {};
template<> struct IsDirectDrawSurfaceIntf<IDirectDrawSurface3> : std::true_type {};
template<> struct IsDirectDrawSurfaceIntf<IDirectDrawSurface4> : std::true_type {};
template<> struct IsDirectDrawSurfaceIntf<IDirectDrawSurface7> : std::true_type {};
template <typename NewIntf, typename OrigIntf>
std::enable_if_t<IsDirectDrawSurfaceIntf<NewIntf>::value && IsDirectDrawSurfaceIntf<OrigIntf>::value>
queryInterface(OrigIntf& origIntf, NewIntf*& newIntf)
{
CompatDirectDrawSurface<OrigIntf>::s_origVtable.QueryInterface(
&origIntf, CompatDirectDrawSurface<NewIntf>::s_iid, reinterpret_cast<void**>(&newIntf));
}
template <typename NewIntf>
std::enable_if_t<IsDirectDrawSurfaceIntf<NewIntf>::value>
queryInterface(IUnknown& origIntf, NewIntf*& newIntf)
{
CompatDirectDrawSurface<IDirectDrawSurface>::s_origVtable.QueryInterface(
reinterpret_cast<IDirectDrawSurface*>(&origIntf),
CompatDirectDrawSurface<NewIntf>::s_iid, reinterpret_cast<void**>(&newIntf));
}
template <typename OrigIntf>
std::enable_if_t<IsDirectDrawSurfaceIntf<OrigIntf>::value>
queryInterface(OrigIntf& origIntf, IUnknown*& newIntf)
{
CompatDirectDrawSurface<OrigIntf>::s_origVtable.QueryInterface(
&origIntf, IID_IUnknown, reinterpret_cast<void**>(&newIntf));
}
}

View File

@ -59,8 +59,8 @@ namespace
{ {
DDSURFACEDESC2 desc = {}; DDSURFACEDESC2 desc = {};
desc.dwSize = sizeof(desc); desc.dwSize = sizeof(desc);
if (FAILED(CompatDirectDrawSurface<IDirectDrawSurface7>::s_origVtable.Lock( auto primary(CompatPrimarySurface::getPrimary());
CompatPrimarySurface::surface, nullptr, &desc, DDLOCK_WAIT, nullptr))) if (FAILED(primary->Lock(primary, nullptr, &desc, DDLOCK_WAIT, nullptr)))
{ {
return false; return false;
} }
@ -74,8 +74,8 @@ namespace
void unlockPrimarySurface() void unlockPrimarySurface()
{ {
GdiFlush(); GdiFlush();
CompatDirectDrawSurface<IDirectDrawSurface7>::s_origVtable.Unlock( auto primary(CompatPrimarySurface::getPrimary());
CompatPrimarySurface::surface, nullptr); primary->Unlock(primary, nullptr);
RealPrimarySurface::invalidate(nullptr); RealPrimarySurface::invalidate(nullptr);
RealPrimarySurface::update(); RealPrimarySurface::update();

View File

@ -4,28 +4,21 @@
#include "CompatDirectDraw.h" #include "CompatDirectDraw.h"
#include "CompatDirectDrawSurface.h" #include "CompatDirectDrawSurface.h"
#include "CompatPrimarySurface.h" #include "CompatPrimarySurface.h"
#include "CompatPtr.h"
#include "IReleaseNotifier.h" #include "IReleaseNotifier.h"
#include "RealPrimarySurface.h" #include "RealPrimarySurface.h"
namespace namespace
{ {
CompatWeakPtr<IDirectDrawSurface> g_primarySurface = nullptr;
std::vector<void*> g_primarySurfacePtrs; std::vector<void*> g_primarySurfacePtrs;
void addPrimary(IDirectDrawSurface7* surface, const IID& iid)
{
IUnknown* intf = nullptr;
CompatDirectDrawSurface<IDirectDrawSurface7>::s_origVtable.QueryInterface(
surface, iid, reinterpret_cast<void**>(&intf));
g_primarySurfacePtrs.push_back(intf);
intf->lpVtbl->Release(intf);
}
void onRelease() void onRelease()
{ {
Compat::LogEnter("CompatPrimarySurface::onRelease"); Compat::LogEnter("CompatPrimarySurface::onRelease");
g_primarySurfacePtrs.clear(); g_primarySurfacePtrs.clear();
CompatPrimarySurface::surface = nullptr; g_primarySurface = nullptr;
CompatPrimarySurface::palette = nullptr; CompatPrimarySurface::palette = nullptr;
CompatPrimarySurface::width = 0; CompatPrimarySurface::width = 0;
CompatPrimarySurface::height = 0; CompatPrimarySurface::height = 0;
@ -59,32 +52,41 @@ namespace CompatPrimarySurface
template DisplayMode getDisplayMode(IDirectDraw4& dd); template DisplayMode getDisplayMode(IDirectDraw4& dd);
template DisplayMode getDisplayMode(IDirectDraw7& dd); template DisplayMode getDisplayMode(IDirectDraw7& dd);
bool isPrimary(void* surfacePtr) CompatPtr<IDirectDrawSurface7> getPrimary()
{ {
return g_primarySurfacePtrs.end() != if (!g_primarySurface)
std::find(g_primarySurfacePtrs.begin(), g_primarySurfacePtrs.end(), surfacePtr); {
return nullptr;
}
return CompatPtr<IDirectDrawSurface7>(
Compat::queryInterface<IDirectDrawSurface7>(g_primarySurface.get()));
} }
void setPrimary(IDirectDrawSurface7* surfacePtr) bool isPrimary(void* surface)
{ {
surface = surfacePtr; return g_primarySurfacePtrs.end() !=
std::find(g_primarySurfacePtrs.begin(), g_primarySurfacePtrs.end(), surface);
}
void setPrimary(CompatRef<IDirectDrawSurface7> surface)
{
CompatPtr<IDirectDrawSurface> surfacePtr(Compat::queryInterface<IDirectDrawSurface>(&surface));
g_primarySurface = surfacePtr;
g_primarySurfacePtrs.clear(); g_primarySurfacePtrs.clear();
g_primarySurfacePtrs.push_back(&surface);
g_primarySurfacePtrs.push_back(CompatPtr<IDirectDrawSurface4>(surfacePtr));
g_primarySurfacePtrs.push_back(CompatPtr<IDirectDrawSurface3>(surfacePtr));
g_primarySurfacePtrs.push_back(CompatPtr<IDirectDrawSurface2>(surfacePtr));
g_primarySurfacePtrs.push_back(surfacePtr); g_primarySurfacePtrs.push_back(surfacePtr);
addPrimary(surfacePtr, IID_IDirectDrawSurface4);
addPrimary(surfacePtr, IID_IDirectDrawSurface3);
addPrimary(surfacePtr, IID_IDirectDrawSurface2);
addPrimary(surfacePtr, IID_IDirectDrawSurface);
IReleaseNotifier* releaseNotifierPtr = &releaseNotifier; IReleaseNotifier* releaseNotifierPtr = &releaseNotifier;
CompatDirectDrawSurface<IDirectDrawSurface7>::s_origVtable.SetPrivateData( surface->SetPrivateData(&surface, IID_IReleaseNotifier,
surfacePtr, IID_IReleaseNotifier, releaseNotifierPtr, sizeof(releaseNotifierPtr), releaseNotifierPtr, sizeof(releaseNotifierPtr), DDSPD_IUNKNOWNPOINTER);
DDSPD_IUNKNOWNPOINTER);
} }
DisplayMode displayMode = {}; DisplayMode displayMode = {};
bool isDisplayModeChanged = false; bool isDisplayModeChanged = false;
IDirectDrawSurface7* surface = nullptr;
LPDIRECTDRAWPALETTE palette = nullptr; LPDIRECTDRAWPALETTE palette = nullptr;
PALETTEENTRY paletteEntries[256] = {}; PALETTEENTRY paletteEntries[256] = {};
LONG width = 0; LONG width = 0;

View File

@ -4,6 +4,9 @@
#include <ddraw.h> #include <ddraw.h>
#include "CompatPtr.h"
#include "CompatRef.h"
class IReleaseNotifier; class IReleaseNotifier;
namespace CompatPrimarySurface namespace CompatPrimarySurface
@ -19,12 +22,12 @@ namespace CompatPrimarySurface
template <typename TDirectDraw> template <typename TDirectDraw>
DisplayMode getDisplayMode(TDirectDraw& dd); DisplayMode getDisplayMode(TDirectDraw& dd);
bool isPrimary(void* surfacePtr); CompatPtr<IDirectDrawSurface7> getPrimary();
void setPrimary(IDirectDrawSurface7* surfacePtr); bool isPrimary(void* surface);
void setPrimary(CompatRef<IDirectDrawSurface7> surface);
extern DisplayMode displayMode; extern DisplayMode displayMode;
extern bool isDisplayModeChanged; extern bool isDisplayModeChanged;
extern IDirectDrawSurface7* surface;
extern LPDIRECTDRAWPALETTE palette; extern LPDIRECTDRAWPALETTE palette;
extern PALETTEENTRY paletteEntries[256]; extern PALETTEENTRY paletteEntries[256];
extern LONG width; extern LONG width;

58
DDrawCompat/CompatPtr.h Normal file
View File

@ -0,0 +1,58 @@
#pragma once
#include <algorithm>
#include "CompatQueryInterface.h"
#include "CompatWeakPtr.h"
template <typename Intf>
class CompatPtr : public CompatWeakPtr<Intf>
{
public:
CompatPtr(std::nullptr_t = nullptr)
{
}
explicit CompatPtr(Intf* intf) : CompatWeakPtr(intf)
{
}
CompatPtr(const CompatPtr& other)
{
m_intf = Compat::queryInterface<Intf>(other.get());
}
template <typename OtherIntf>
CompatPtr(const CompatPtr<OtherIntf>& other)
{
m_intf = Compat::queryInterface<Intf>(other.get());
}
~CompatPtr()
{
release();
}
CompatPtr& operator=(CompatPtr rhs)
{
swap(rhs);
return *this;
}
Intf* detach()
{
Intf* intf = m_intf;
m_intf = nullptr;
return intf;
}
void reset(Intf* intf = nullptr)
{
*this = CompatPtr(intf);
}
void swap(CompatPtr& other)
{
std::swap(m_intf, other.m_intf);
}
};

View File

@ -0,0 +1,28 @@
#pragma once
struct IUnknown;
namespace Compat
{
template <typename Intf>
void queryInterface(Intf& origIntf, Intf*& newIntf)
{
newIntf = &origIntf;
newIntf->lpVtbl->AddRef(newIntf);
}
void queryInterface(IUnknown&, IUnknown*&) = delete;
template <typename NewIntf, typename OrigIntf>
NewIntf* queryInterface(OrigIntf* origIntf)
{
if (!origIntf)
{
return nullptr;
}
NewIntf* newIntf = nullptr;
queryInterface(*origIntf, newIntf);
return newIntf;
}
}

30
DDrawCompat/CompatRef.h Normal file
View File

@ -0,0 +1,30 @@
#pragma once
#include "CompatVtable.h"
template <typename Intf>
class CompatRef
{
public:
CompatRef(Intf& intf) : m_intf(intf)
{
}
const Vtable<Intf>* operator->() const
{
return &CompatVtableBase<Intf>::getOrigVtable(m_intf);
}
Intf* operator&() const
{
return &m_intf;
}
Intf& get() const
{
return m_intf;
}
private:
Intf& m_intf;
};

View File

@ -11,12 +11,24 @@
template <typename Interface> template <typename Interface>
using Vtable = typename std::remove_pointer<decltype(Interface::lpVtbl)>::type; using Vtable = typename std::remove_pointer<decltype(Interface::lpVtbl)>::type;
template <typename CompatInterface, typename Interface> template <typename Interface>
class CompatVtable class CompatVtableBase
{ {
public: public:
typedef Interface Interface; typedef Interface Interface;
static const Vtable<Interface>& getOrigVtable(Interface& intf)
{
return s_origVtable.AddRef ? s_origVtable : *intf.lpVtbl;
}
static Vtable<Interface> s_origVtable;
};
template <typename CompatInterface, typename Interface>
class CompatVtable : public CompatVtableBase<Interface>
{
public:
static void hookVtable(Interface& intf) static void hookVtable(Interface& intf)
{ {
static bool isInitialized = false; static bool isInitialized = false;
@ -32,8 +44,6 @@ public:
} }
} }
static Vtable<Interface> s_origVtable;
private: private:
class InitVisitor class InitVisitor
{ {
@ -122,11 +132,11 @@ private:
static std::map<std::vector<unsigned char>, std::string> s_funcNames; static std::map<std::vector<unsigned char>, std::string> s_funcNames;
}; };
template <typename CompatInterface, typename Interface> template <typename Interface>
Vtable<Interface>* CompatVtable<CompatInterface, Interface>::s_vtablePtr = nullptr; Vtable<Interface> CompatVtableBase<Interface>::s_origVtable = {};
template <typename CompatInterface, typename Interface> template <typename CompatInterface, typename Interface>
Vtable<Interface> CompatVtable<CompatInterface, Interface>::s_origVtable = {}; Vtable<Interface>* CompatVtable<CompatInterface, Interface>::s_vtablePtr = nullptr;
template <typename CompatInterface, typename Interface> template <typename CompatInterface, typename Interface>
Vtable<Interface> CompatVtable<CompatInterface, Interface>::s_compatVtable(CompatInterface::getCompatVtable()); Vtable<Interface> CompatVtable<CompatInterface, Interface>::s_compatVtable(CompatInterface::getCompatVtable());

View File

@ -0,0 +1,54 @@
#pragma once
#include "CompatVtable.h"
template <typename Intf>
class CompatWeakPtr
{
public:
CompatWeakPtr(Intf* intf = nullptr) : m_intf(intf)
{
}
Intf& operator*() const
{
return *m_intf;
}
const Vtable<Intf>* operator->() const
{
return &CompatVtableBase<Intf>::getOrigVtable(*m_intf);
}
operator Intf*() const
{
return m_intf;
}
Intf* get() const
{
return m_intf;
}
Intf* const& getRef() const
{
return m_intf;
}
Intf*& getRef()
{
return m_intf;
}
void release()
{
if (m_intf)
{
m_intf->lpVtbl->Release(m_intf);
m_intf = nullptr;
}
}
protected:
Intf* m_intf;
};

View File

@ -157,7 +157,11 @@
<ClInclude Include="CompatGdiTitleBar.h" /> <ClInclude Include="CompatGdiTitleBar.h" />
<ClInclude Include="CompatGdiWinProc.h" /> <ClInclude Include="CompatGdiWinProc.h" />
<ClInclude Include="CompatPaletteConverter.h" /> <ClInclude Include="CompatPaletteConverter.h" />
<ClInclude Include="CompatPtr.h" />
<ClInclude Include="CompatQueryInterface.h" />
<ClInclude Include="CompatRef.h" />
<ClInclude Include="CompatRegistry.h" /> <ClInclude Include="CompatRegistry.h" />
<ClInclude Include="CompatWeakPtr.h" />
<ClInclude Include="Config.h" /> <ClInclude Include="Config.h" />
<ClInclude Include="DDrawProcs.h" /> <ClInclude Include="DDrawProcs.h" />
<ClInclude Include="CompatDirectDraw.h" /> <ClInclude Include="CompatDirectDraw.h" />

View File

@ -114,6 +114,18 @@
<ClInclude Include="CompatActivateAppHandler.h"> <ClInclude Include="CompatActivateAppHandler.h">
<Filter>Header Files</Filter> <Filter>Header Files</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="CompatPtr.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="CompatQueryInterface.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="CompatWeakPtr.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="CompatRef.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="DllMain.cpp"> <ClCompile Include="DllMain.cpp">

View File

@ -56,10 +56,11 @@ namespace
clipper->lpVtbl->Release(clipper); clipper->lpVtbl->Release(clipper);
} }
auto primary(CompatPrimarySurface::getPrimary());
if (CompatPrimarySurface::pixelFormat.dwRGBBitCount <= 8) if (CompatPrimarySurface::pixelFormat.dwRGBBitCount <= 8)
{ {
origVtable.Blt(CompatPaletteConverter::getSurface(), &g_updateRect, origVtable.Blt(CompatPaletteConverter::getSurface(), &g_updateRect,
CompatPrimarySurface::surface, &g_updateRect, DDBLT_WAIT, nullptr); primary, &g_updateRect, DDBLT_WAIT, nullptr);
HDC destDc = nullptr; HDC destDc = nullptr;
origVtable.GetDC(dest, &destDc); origVtable.GetDC(dest, &destDc);
@ -79,7 +80,7 @@ namespace
else else
{ {
result = SUCCEEDED(origVtable.Blt(dest, &g_updateRect, result = SUCCEEDED(origVtable.Blt(dest, &g_updateRect,
CompatPrimarySurface::surface, &g_updateRect, DDBLT_WAIT, nullptr)); primary, &g_updateRect, DDBLT_WAIT, nullptr));
} }
if (result) if (result)
@ -394,6 +395,9 @@ void RealPrimarySurface::updatePalette(DWORD startingEntry, DWORD count)
{ {
CompatPaletteConverter::updatePalette(startingEntry, count); CompatPaletteConverter::updatePalette(startingEntry, count);
CompatGdi::updatePalette(startingEntry, count); CompatGdi::updatePalette(startingEntry, count);
invalidate(nullptr); if (CompatPrimarySurface::palette)
update(); {
invalidate(nullptr);
update();
}
} }