diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index 84698c4c..b99b7c00 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -60,6 +60,8 @@ namespace dxvk { case DxbcInstClass::ControlFlow: return this->handleControlFlow (parsedInst); case DxbcInstClass::TextureSample: return this->handleTextureSample(parsedInst); case DxbcInstClass::VectorAlu: return this->handleVectorAlu (parsedInst); + case DxbcInstClass::VectorCmov: return this->handleVectorCmov (parsedInst); + case DxbcInstClass::VectorCmp: return this->handleVectorCmp (parsedInst); case DxbcInstClass::VectorDot: return this->handleVectorDot (parsedInst); case DxbcInstClass::VectorSinCos: return this->handleVectorSinCos (parsedInst); default: return DxbcError::eUnhandledOpcode; @@ -669,6 +671,13 @@ namespace dxvk { arguments[1].valueId); break; + case DxbcOpcode::Div: + result.valueId = m_module.opFDiv( + resultTypeId, + arguments[0].valueId, + arguments[1].valueId); + break; + case DxbcOpcode::Mad: result.valueId = m_module.opFFma( resultTypeId, @@ -718,6 +727,137 @@ namespace dxvk { } + DxbcError DxbcCompiler::handleVectorCmov(const DxbcInst& ins) { + // movc has four operands: + // (1) The destination register + // (2) The condition vector + // (3) Vector to select from if the condition is not 0 + // (4) Vector to select from if the condition is 0 + const DxbcValue condition = this->loadOp(ins.operands[1], ins.operands[0].mask, ins.format.operands[1].type); + const DxbcValue selectTrue = this->loadOp(ins.operands[2], ins.operands[0].mask, ins.format.operands[2].type); + const DxbcValue selectFalse = this->loadOp(ins.operands[3], ins.operands[0].mask, ins.format.operands[3].type); + + const uint32_t componentCount = ins.operands[0].mask.setCount(); + + // We'll compare against a vector of zeroes to generate a + // boolean vector, which in turn will be used by OpSelect + uint32_t zeroType = m_module.defIntType(32, 0); + uint32_t boolType = m_module.defBoolType(); + + uint32_t zero = m_module.constu32(0); + + if (componentCount > 1) { + zeroType = m_module.defVectorType(zeroType, componentCount); + boolType = m_module.defVectorType(boolType, componentCount); + + const std::array zeroVec = { zero, zero, zero, zero }; + zero = m_module.constComposite(zeroType, componentCount, zeroVec.data()); + } + + // Use the component mask to select the vector components + DxbcValue result; + result.componentType = ins.format.operands[0].type; + result.componentCount = componentCount; + result.valueId = m_module.opSelect( + this->defineVectorType(result.componentType, result.componentCount), + m_module.opINotEqual(boolType, condition.valueId, zero), + selectTrue.valueId, selectFalse.valueId); + + // Apply result modifiers to floating-point results + result = this->applyResultModifiers(result, ins.control); + this->storeOp(ins.operands[0], result); + return DxbcError::sOk; + } + + + DxbcError DxbcCompiler::handleVectorCmp(const DxbcInst& ins) { + // Compare instructions have three operands: + // (1) The destination register + // (2) The first vector to compare + // (3) The second vector to compare + DxbcValue arguments[2]; + + if (ins.format.operandCount != 3) + return DxbcError::eInvalidOperand; + + for (uint32_t i = 0; i < 2; i++) { + arguments[i] = this->loadOp( + ins.operands[i + 1], + ins.operands[0].mask, + ins.format.operands[i + 1].type); + } + + const uint32_t componentCount = ins.operands[0].mask.setCount(); + + // Condition, which is a boolean vector used + // to select between the ~0u and 0u vectors. + uint32_t condition = 0; + uint32_t conditionType = m_module.defBoolType(); + + if (componentCount > 1) + conditionType = m_module.defVectorType(conditionType, componentCount); + + switch (ins.opcode) { + case DxbcOpcode::Eq: + condition = m_module.opFOrdEqual( + conditionType, + arguments[0].valueId, + arguments[1].valueId); + break; + + case DxbcOpcode::Ge: + condition = m_module.opFOrdGreaterThanEqual( + conditionType, + arguments[0].valueId, + arguments[1].valueId); + break; + + case DxbcOpcode::Lt: + condition = m_module.opFOrdLessThan( + conditionType, + arguments[0].valueId, + arguments[1].valueId); + break; + + case DxbcOpcode::Ne: + condition = m_module.opFOrdNotEqual( + conditionType, + arguments[0].valueId, + arguments[1].valueId); + break; + + default: + return DxbcError::eUnhandledOpcode; + } + + // Generate constant vectors for selection + uint32_t sFalse = m_module.constu32( 0u); + uint32_t sTrue = m_module.constu32(~0u); + + const uint32_t maskType = this->defineVectorType( + DxbcScalarType::Uint32, componentCount); + + if (componentCount > 1) { + const std::array vFalse = { sFalse, sFalse, sFalse, sFalse }; + const std::array vTrue = { sTrue, sTrue, sTrue, sTrue }; + + sFalse = m_module.constComposite(maskType, componentCount, vFalse.data()); + sTrue = m_module.constComposite(maskType, componentCount, vTrue .data()); + } + + // Perform component-wise mask selection + // based on the condition evaluated above. + DxbcValue result; + result.componentType = DxbcScalarType::Uint32; + result.componentCount = componentCount; + result.valueId = m_module.opSelect( + maskType, condition, sTrue, sFalse); + + this->storeOp(ins.operands[0], result); + return DxbcError::sOk; + } + + DxbcError DxbcCompiler::handleVectorDot(const DxbcInst& ins) { // Determine the component count and the source // operand mask. Since the result is scalar, we diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index 137b89ad..57870b80 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -287,6 +287,12 @@ namespace dxvk { DxbcError handleVectorAlu( const DxbcInst& ins); + DxbcError handleVectorCmov( + const DxbcInst& ins); + + DxbcError handleVectorCmp( + const DxbcInst& ins); + DxbcError handleVectorDot( const DxbcInst& ins); diff --git a/src/dxbc/dxbc_defs.cpp b/src/dxbc/dxbc_defs.cpp index 3e958277..b49391a0 100644 --- a/src/dxbc/dxbc_defs.cpp +++ b/src/dxbc/dxbc_defs.cpp @@ -36,7 +36,11 @@ namespace dxvk { /* Discard */ { }, /* Div */ - { }, + { 3, DxbcInstClass::VectorAlu, { + { DxbcOperandKind::DstReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + } }, /* Dp2 */ { 3, DxbcInstClass::VectorDot, { { DxbcOperandKind::DstReg, DxbcScalarType::Float32 }, @@ -68,7 +72,11 @@ namespace dxvk { /* EndSwitch */ { }, /* Eq */ - { }, + { 3, DxbcInstClass::VectorCmp, { + { DxbcOperandKind::DstReg, DxbcScalarType::Uint32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + } }, /* Exp */ { }, /* Frc */ @@ -78,7 +86,11 @@ namespace dxvk { /* FtoU */ { }, /* Ge */ - { }, + { 3, DxbcInstClass::VectorCmp, { + { DxbcOperandKind::DstReg, DxbcScalarType::Uint32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + } }, /* IAdd */ { }, /* If */ @@ -118,7 +130,11 @@ namespace dxvk { /* Loop */ { }, /* Lt */ - { }, + { 3, DxbcInstClass::VectorCmp, { + { DxbcOperandKind::DstReg, DxbcScalarType::Uint32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + } }, /* Mad */ { 4, DxbcInstClass::VectorAlu, { { DxbcOperandKind::DstReg, DxbcScalarType::Float32 }, @@ -146,7 +162,12 @@ namespace dxvk { { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, } }, /* Movc */ - { }, + { 4, DxbcInstClass::VectorCmov, { + { DxbcOperandKind::DstReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Uint32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + } }, /* Mul */ { 3, DxbcInstClass::VectorAlu, { { DxbcOperandKind::DstReg, DxbcScalarType::Float32 }, @@ -154,7 +175,11 @@ namespace dxvk { { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, } }, /* Ne */ - { }, + { 3, DxbcInstClass::VectorCmp, { + { DxbcOperandKind::DstReg, DxbcScalarType::Uint32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + { DxbcOperandKind::SrcReg, DxbcScalarType::Float32 }, + } }, /* Nop */ { }, /* Not */ diff --git a/src/dxbc/dxbc_defs.h b/src/dxbc/dxbc_defs.h index 4c66fd4e..44367e7c 100644 --- a/src/dxbc/dxbc_defs.h +++ b/src/dxbc/dxbc_defs.h @@ -31,6 +31,7 @@ namespace dxvk { Declaration, ///< Interface or resource declaration TextureSample, ///< Texture sampling instruction VectorAlu, ///< Component-wise vector instructions + VectorCmov, ///< Component-wise conditional move VectorCmp, ///< Component-wise vector comparison VectorDot, ///< Dot product instruction VectorSinCos, ///< Sine and Cosine instruction diff --git a/src/spirv/spirv_module.cpp b/src/spirv/spirv_module.cpp index 407f9a92..2b9cdc37 100644 --- a/src/spirv/spirv_module.cpp +++ b/src/spirv/spirv_module.cpp @@ -738,6 +738,21 @@ namespace dxvk { } + uint32_t SpirvModule::opFDiv( + uint32_t resultType, + uint32_t a, + uint32_t b) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFDiv, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(a); + m_code.putWord(b); + return resultId; + } + + uint32_t SpirvModule::opIMul( uint32_t resultType, uint32_t a, @@ -840,6 +855,126 @@ namespace dxvk { } + uint32_t SpirvModule::opIEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpIEqual, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opINotEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpINotEqual, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opFOrdEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFOrdEqual, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opFOrdNotEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFOrdNotEqual, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opFOrdLessThan( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFOrdLessThan, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opFOrdLessThanEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFOrdLessThanEqual, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opFOrdGreaterThan( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFOrdGreaterThan, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + + uint32_t SpirvModule::opFOrdGreaterThanEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpFOrdGreaterThanEqual, 5); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(vector1); + m_code.putWord(vector2); + return resultId; + } + + uint32_t SpirvModule::opDot( uint32_t resultType, uint32_t vector1, @@ -900,6 +1035,23 @@ namespace dxvk { } + uint32_t SpirvModule::opSelect( + uint32_t resultType, + uint32_t condition, + uint32_t operand1, + uint32_t operand2) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpSelect, 6); + m_code.putWord(resultType); + m_code.putWord(resultId); + m_code.putWord(condition); + m_code.putWord(operand1); + m_code.putWord(operand2); + return resultId; + } + + uint32_t SpirvModule::opFunctionCall( uint32_t resultType, uint32_t functionId, diff --git a/src/spirv/spirv_module.h b/src/spirv/spirv_module.h index c20ba9aa..c3e6537c 100644 --- a/src/spirv/spirv_module.h +++ b/src/spirv/spirv_module.h @@ -265,6 +265,11 @@ namespace dxvk { uint32_t a, uint32_t b); + uint32_t opFDiv( + uint32_t resultType, + uint32_t a, + uint32_t b); + uint32_t opIMul( uint32_t resultType, uint32_t a, @@ -297,6 +302,46 @@ namespace dxvk { uint32_t minVal, uint32_t maxVal); + uint32_t opIEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opINotEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opFOrdEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opFOrdNotEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opFOrdLessThan( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opFOrdLessThanEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opFOrdGreaterThan( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + + uint32_t opFOrdGreaterThanEqual( + uint32_t resultType, + uint32_t vector1, + uint32_t vector2); + uint32_t opDot( uint32_t resultType, uint32_t vector1, @@ -314,6 +359,12 @@ namespace dxvk { uint32_t resultType, uint32_t x); + uint32_t opSelect( + uint32_t resultType, + uint32_t condition, + uint32_t operand1, + uint32_t operand2); + uint32_t opFunctionCall( uint32_t resultType, uint32_t functionId,