From 621aed5fdbaf92764944be7b3a27cbb3df63ba94 Mon Sep 17 00:00:00 2001
From: Philip Rebohle <philip.rebohle@tu-dortmund.de>
Date: Thu, 31 May 2018 10:13:32 +0200
Subject: [PATCH] [dxbc] Bound-check dynamically indexed constant buffer reads

Emulates D3D11 behaviour more closely on Nvidia hardware.
Fixes an issue in Dark Souls Remastered caused by constant
buffer access with an undefined index value (#405).
---
 src/dxbc/dxbc_compiler.cpp | 97 +++++++++++++++++++++++++-------------
 src/dxbc/dxbc_compiler.h   | 10 ++--
 2 files changed, 71 insertions(+), 36 deletions(-)

diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp
index abeb1658..afcc55a9 100644
--- a/src/dxbc/dxbc_compiler.cpp
+++ b/src/dxbc/dxbc_compiler.cpp
@@ -4210,35 +4210,6 @@ namespace dxvk {
   }
   
   
-  DxbcRegisterPointer DxbcCompiler::emitGetConstBufPtr(
-    const DxbcRegister&           operand) {
-    // Constant buffers take a two-dimensional index:
-    //    (0) register index (immediate)
-    //    (1) constant offset (relative)
-    DxbcRegisterInfo info;
-    info.type.ctype   = DxbcScalarType::Float32;
-    info.type.ccount  = 4;
-    info.type.alength = 0;
-    info.sclass = spv::StorageClassUniform;
-    
-    const uint32_t regId = operand.idx[0].offset;
-    const DxbcRegisterValue constId = emitIndexLoad(operand.idx[1]);
-    
-    const uint32_t ptrTypeId = getPointerTypeId(info);
-    
-    const std::array<uint32_t, 2> indices =
-      {{ m_module.consti32(0), constId.id }};
-    
-    DxbcRegisterPointer result;
-    result.type.ctype  = info.type.ctype;
-    result.type.ccount = info.type.ccount;
-    result.id = m_module.opAccessChain(ptrTypeId,
-      m_constantBuffers.at(regId).varId,
-      indices.size(), indices.data());
-    return result;
-  }
-  
-  
   DxbcRegisterPointer DxbcCompiler::emitGetImmConstBufPtr(
     const DxbcRegister&           operand) {
     if (m_immConstBuf == 0)
@@ -4281,9 +4252,6 @@ namespace dxvk {
       case DxbcOperandType::Output:
         return emitGetOutputPtr(operand);
       
-      case DxbcOperandType::ConstantBuffer:
-        return emitGetConstBufPtr(operand);
-      
       case DxbcOperandType::ImmediateConstantBuffer:
         return emitGetImmConstBufPtr(operand);
       
@@ -4740,6 +4708,22 @@ namespace dxvk {
   }
   
   
+  DxbcRegisterValue DxbcCompiler::emitIndexBoundCheck(
+          DxbcRegisterValue       index,
+          DxbcRegisterValue       count) {
+    index = emitRegisterBitcast(index, DxbcScalarType::Uint32);
+    count = emitRegisterBitcast(count, DxbcScalarType::Uint32);
+
+    DxbcRegisterValue result;
+    result.type.ctype  = DxbcScalarType::Bool;
+    result.type.ccount = index.type.ccount;
+    result.id = m_module.opULessThan(
+      getVectorTypeId(result.type),
+      index.id, count.id);
+    return result;
+  }
+
+
   DxbcRegisterValue DxbcCompiler::emitIndexLoad(
           DxbcRegIndex            index) {
     if (index.relReg != nullptr) {
@@ -4802,9 +4786,56 @@ namespace dxvk {
   }
   
   
+  DxbcRegisterValue DxbcCompiler::emitConstBufLoadRaw(
+    const DxbcRegister&           operand) {
+    // Constant buffers take a two-dimensional index:
+    //    (0) register index (immediate)
+    //    (1) constant offset (relative)
+    DxbcRegisterInfo info;
+    info.type.ctype   = DxbcScalarType::Float32;
+    info.type.ccount  = 4;
+    info.type.alength = 0;
+    info.sclass = spv::StorageClassUniform;
+    
+    const uint32_t regId = operand.idx[0].offset;
+    const DxbcRegisterValue constId = emitIndexLoad(operand.idx[1]);
+    
+    const uint32_t ptrTypeId = getPointerTypeId(info);
+    
+    const std::array<uint32_t, 2> indices =
+      {{ m_module.consti32(0), constId.id }};
+    
+    DxbcRegisterPointer pointer;
+    pointer.type.ctype  = info.type.ctype;
+    pointer.type.ccount = info.type.ccount;
+    pointer.id = m_module.opAccessChain(ptrTypeId,
+      m_constantBuffers.at(regId).varId,
+      indices.size(), indices.data());
+    
+    DxbcRegisterValue value = emitValueLoad(pointer);
+
+    // For dynamically indexed constant buffers, we should
+    // return a vec4(ß.0f) if the index is out of bounds
+    if (operand.idx[1].relReg != nullptr) {
+      DxbcRegisterValue cbSize;
+      cbSize.type = { DxbcScalarType::Uint32, 1 };
+      cbSize.id   = m_module.constu32(m_constantBuffers.at(regId).size);
+      DxbcRegisterValue inBounds = emitRegisterExtend(emitIndexBoundCheck(constId, cbSize), 4);
+      
+      value.id = m_module.opSelect(
+        getVectorTypeId(value.type), inBounds.id, value.id,
+        m_module.constvec4f32(0.0f, 0.0f, 0.0f, 0.0f));
+    }
+
+    return value;
+  }
+
+
   DxbcRegisterValue DxbcCompiler::emitRegisterLoadRaw(
     const DxbcRegister&           reg) {
-    return emitValueLoad(emitGetOperandPtr(reg));
+    return reg.type == DxbcOperandType::ConstantBuffer
+      ? emitConstBufLoadRaw(reg)
+      : emitValueLoad(emitGetOperandPtr(reg));
   }
   
   
diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h
index 38115a08..a48ffea5 100644
--- a/src/dxbc/dxbc_compiler.h
+++ b/src/dxbc/dxbc_compiler.h
@@ -803,9 +803,6 @@ namespace dxvk {
     DxbcRegisterPointer emitGetOutputPtr(
       const DxbcRegister&           operand);
     
-    DxbcRegisterPointer emitGetConstBufPtr(
-      const DxbcRegister&           operand);
-    
     DxbcRegisterPointer emitGetImmConstBufPtr(
       const DxbcRegister&           operand);
     
@@ -863,6 +860,10 @@ namespace dxvk {
     
     //////////////////////////////
     // Operand load/store methods
+    DxbcRegisterValue emitIndexBoundCheck(
+            DxbcRegisterValue       index,
+            DxbcRegisterValue       count);
+
     DxbcRegisterValue emitIndexLoad(
             DxbcRegIndex            index);
     
@@ -874,6 +875,9 @@ namespace dxvk {
             DxbcRegisterValue       value,
             DxbcRegMask             writeMask);
     
+    DxbcRegisterValue emitConstBufLoadRaw(
+      const DxbcRegister&           operand);
+    
     DxbcRegisterValue emitRegisterLoadRaw(
       const DxbcRegister&           reg);