diff --git a/src/dxvk/dxvk_graphics.cpp b/src/dxvk/dxvk_graphics.cpp index cadd72a0..0f22bd2f 100644 --- a/src/dxvk/dxvk_graphics.cpp +++ b/src/dxvk/dxvk_graphics.cpp @@ -180,18 +180,11 @@ namespace dxvk { VkSpecializationInfo specInfo = specData.getSpecInfo(); - DxvkShaderModuleCreateInfo moduleInfo; - moduleInfo.fsDualSrcBlend = state.omBlend[0].blendEnable() && ( - util::isDualSourceBlendFactor(state.omBlend[0].srcColorBlendFactor()) || - util::isDualSourceBlendFactor(state.omBlend[0].dstColorBlendFactor()) || - util::isDualSourceBlendFactor(state.omBlend[0].srcAlphaBlendFactor()) || - util::isDualSourceBlendFactor(state.omBlend[0].dstAlphaBlendFactor())); - - auto vsm = createShaderModule(m_shaders.vs, moduleInfo); - auto gsm = createShaderModule(m_shaders.gs, moduleInfo); - auto tcsm = createShaderModule(m_shaders.tcs, moduleInfo); - auto tesm = createShaderModule(m_shaders.tes, moduleInfo); - auto fsm = createShaderModule(m_shaders.fs, moduleInfo); + auto vsm = createShaderModule(m_shaders.vs, state); + auto tcsm = createShaderModule(m_shaders.tcs, state); + auto tesm = createShaderModule(m_shaders.tes, state); + auto gsm = createShaderModule(m_shaders.gs, state); + auto fsm = createShaderModule(m_shaders.fs, state); std::vector stages; if (vsm) stages.push_back(vsm.stageInfo(&specInfo)); @@ -433,25 +426,69 @@ namespace dxvk { DxvkShaderModule DxvkGraphicsPipeline::createShaderModule( const Rc& shader, - const DxvkShaderModuleCreateInfo& info) const { - return shader != nullptr - ? shader->createShaderModule(m_vkd, m_slotMapping, info) - : DxvkShaderModule(); + const DxvkGraphicsPipelineStateInfo& state) const { + if (shader == nullptr) + return DxvkShaderModule(); + + DxvkShaderModuleCreateInfo info; + + // Fix up fragment shader outputs for dual-source blending + if (shader->stage() == VK_SHADER_STAGE_FRAGMENT_BIT) { + info.fsDualSrcBlend = state.omBlend[0].blendEnable() && ( + util::isDualSourceBlendFactor(state.omBlend[0].srcColorBlendFactor()) || + util::isDualSourceBlendFactor(state.omBlend[0].dstColorBlendFactor()) || + util::isDualSourceBlendFactor(state.omBlend[0].srcAlphaBlendFactor()) || + util::isDualSourceBlendFactor(state.omBlend[0].dstAlphaBlendFactor())); + } + + // Deal with undefined shader inputs + uint32_t consumedInputs = shader->interfaceSlots().inputSlots; + uint32_t providedInputs = 0; + + if (shader->stage() == VK_SHADER_STAGE_VERTEX_BIT) { + for (uint32_t i = 0; i < state.il.attributeCount(); i++) + providedInputs |= 1u << state.ilAttributes[i].location(); + } else if (shader->stage() != VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) { + auto prevStage = getPrevStageShader(shader->stage()); + providedInputs = prevStage->interfaceSlots().outputSlots; + } else { + // Technically not correct, but this + // would need a lot of extra care + providedInputs = consumedInputs; + } + + info.undefinedInputs = (providedInputs & consumedInputs) ^ consumedInputs; + return shader->createShaderModule(m_vkd, m_slotMapping, info); + } + + + Rc DxvkGraphicsPipeline::getPrevStageShader(VkShaderStageFlagBits stage) const { + if (stage == VK_SHADER_STAGE_VERTEX_BIT) + return nullptr; + + if (stage == VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) + return m_shaders.tcs; + + Rc result = m_shaders.vs; + + if (stage == VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT) + return result; + + if (m_shaders.tes != nullptr) + result = m_shaders.tes; + + if (stage == VK_SHADER_STAGE_GEOMETRY_BIT) + return result; + + if (m_shaders.gs != nullptr) + result = m_shaders.gs; + + return result; } bool DxvkGraphicsPipeline::validatePipelineState( const DxvkGraphicsPipelineStateInfo& state) const { - // Validate vertex input - each input slot consumed by the - // vertex shader must be provided by the input layout. - uint32_t providedVertexInputs = 0; - - for (uint32_t i = 0; i < state.il.attributeCount(); i++) - providedVertexInputs |= 1u << state.ilAttributes[i].location(); - - if ((providedVertexInputs & m_vsIn) != m_vsIn) - return false; - // Tessellation shaders and patches must be used together bool hasPatches = state.ia.primitiveTopology() == VK_PRIMITIVE_TOPOLOGY_PATCH_LIST; diff --git a/src/dxvk/dxvk_graphics.h b/src/dxvk/dxvk_graphics.h index 538af310..4194599d 100644 --- a/src/dxvk/dxvk_graphics.h +++ b/src/dxvk/dxvk_graphics.h @@ -240,8 +240,11 @@ namespace dxvk { DxvkShaderModule createShaderModule( const Rc& shader, - const DxvkShaderModuleCreateInfo& info) const; + const DxvkGraphicsPipelineStateInfo& state) const; + Rc getPrevStageShader( + VkShaderStageFlagBits stage) const; + bool validatePipelineState( const DxvkGraphicsPipelineStateInfo& state) const; @@ -252,7 +255,7 @@ namespace dxvk { void logPipelineState( LogLevel level, const DxvkGraphicsPipelineStateInfo& state) const; - + }; } \ No newline at end of file