@@ -1516,10 +1516,28 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
1516
1516
1517
1517
for (const auto & info : createInfos)
1518
1518
{
1519
- core::unordered_map<const asset::IShader*, uint32_t > shaderIndexes;
1520
- auto getVkShaderIndex = [&](const asset::IShader* shader)
1519
+ struct VkShaderStageKey
1521
1520
{
1522
- const auto index = shader == nullptr ? VK_SHADER_UNUSED_KHR : shaderIndexes[shader];
1521
+ const asset::IShader* shader;
1522
+ std::string_view entryPoint;
1523
+ bool operator ==(const VkShaderStageKey& other) const = default ;
1524
+
1525
+ struct HashFunction
1526
+ {
1527
+ size_t operator ()(const VkShaderStageKey& key) const
1528
+ {
1529
+ size_t rowHash = std::hash<const asset::IShader*>()(key.shader );
1530
+ size_t colHash = std::hash<std::string_view>()(key.entryPoint ) << 1 ;
1531
+ return rowHash ^ colHash;
1532
+ }
1533
+ };
1534
+ };
1535
+
1536
+ core::unordered_map<VkShaderStageKey, uint32_t , VkShaderStageKey::HashFunction> shaderIndexes;
1537
+ auto getVkShaderIndex = [&](const IGPUPipelineBase::SShaderSpecInfo& spec)
1538
+ {
1539
+ const auto key = VkShaderStageKey{ spec.shader , spec.entryPoint };
1540
+ const auto index = key.shader == nullptr ? VK_SHADER_UNUSED_KHR : shaderIndexes[key];
1523
1541
return index;
1524
1542
};
1525
1543
@@ -1529,7 +1547,7 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
1529
1547
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
1530
1548
.pNext = nullptr ,
1531
1549
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
1532
- .generalShader = getVkShaderIndex (spec.shader ),
1550
+ .generalShader = getVkShaderIndex ({ spec.shader , spec. entryPoint } ),
1533
1551
.closestHitShader = VK_SHADER_UNUSED_KHR,
1534
1552
.anyHitShader = VK_SHADER_UNUSED_KHR,
1535
1553
.intersectionShader = VK_SHADER_UNUSED_KHR,
@@ -1543,9 +1561,9 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
1543
1561
.type = group.intersection .shader == nullptr ?
1544
1562
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR,
1545
1563
.generalShader = VK_SHADER_UNUSED_KHR,
1546
- .closestHitShader = getVkShaderIndex (group.closestHit . shader ),
1547
- .anyHitShader = getVkShaderIndex (group.anyHit . shader ),
1548
- .intersectionShader = getVkShaderIndex (group.intersection . shader ),
1564
+ .closestHitShader = getVkShaderIndex (group.closestHit ),
1565
+ .anyHitShader = getVkShaderIndex (group.anyHit ),
1566
+ .intersectionShader = getVkShaderIndex (group.intersection ),
1549
1567
};
1550
1568
};
1551
1569
@@ -1554,9 +1572,10 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
1554
1572
auto processSpecInfo = [&](const IGPUPipelineBase::SShaderSpecInfo& spec, hlsl::ShaderStage shaderStage)
1555
1573
{
1556
1574
if (!spec.shader ) return ;
1557
- if (shaderIndexes.find (spec.shader ) == shaderIndexes.end ())
1575
+ const auto key = VkShaderStageKey{ spec.shader , spec.entryPoint };
1576
+ if (shaderIndexes.find (key) == shaderIndexes.end ())
1558
1577
{
1559
- shaderIndexes.insert ({ spec. shader , std::distance<decltype (outCreateInfo->pStages )>(outCreateInfo->pStages , outShaderStage)});
1578
+ shaderIndexes.insert ({ key , std::distance<decltype (outCreateInfo->pStages )>(outCreateInfo->pStages , outShaderStage)});
1560
1579
*(outShaderStage) = getVkShaderStageCreateInfoFrom (spec, shaderStage, false , outShaderModule, outEntryPoints, outRequiredSubgroupSize, outSpecInfo,outSpecMapEntry,outSpecData);
1561
1580
outShaderStage++;
1562
1581
}
0 commit comments