Skip to content

Commit 3767ede

Browse files
author
kevyuu
committed
Fix shader indexing logic in ray tracing pipeline creation
1 parent fc1983f commit 3767ede

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

src/nbl/video/CVulkanLogicalDevice.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,10 +1516,28 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
15161516

15171517
for (const auto& info : createInfos)
15181518
{
1519-
core::unordered_map<const asset::IShader*, uint32_t> shaderIndexes;
1520-
auto getVkShaderIndex = [&](const asset::IShader* shader)
1519+
struct VkShaderStageKey
15211520
{
1522-
const auto index = shader == nullptr ? VK_SHADER_UNUSED_KHR : shaderIndexes[shader];
1521+
const asset::IShader* shader;
1522+
std::string_view entryPoint;
1523+
bool operator==(const VkShaderStageKey& other) const = default;
1524+
1525+
struct HashFunction
1526+
{
1527+
size_t operator()(const VkShaderStageKey& key) const
1528+
{
1529+
size_t rowHash = std::hash<const asset::IShader*>()(key.shader);
1530+
size_t colHash = std::hash<std::string_view>()(key.entryPoint) << 1;
1531+
return rowHash ^ colHash;
1532+
}
1533+
};
1534+
};
1535+
1536+
core::unordered_map<VkShaderStageKey, uint32_t, VkShaderStageKey::HashFunction> shaderIndexes;
1537+
auto getVkShaderIndex = [&](const IGPUPipelineBase::SShaderSpecInfo& spec)
1538+
{
1539+
const auto key = VkShaderStageKey{ spec.shader, spec.entryPoint };
1540+
const auto index = key.shader == nullptr ? VK_SHADER_UNUSED_KHR : shaderIndexes[key];
15231541
return index;
15241542
};
15251543

@@ -1529,7 +1547,7 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
15291547
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
15301548
.pNext = nullptr,
15311549
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
1532-
.generalShader = getVkShaderIndex(spec.shader),
1550+
.generalShader = getVkShaderIndex({spec.shader, spec.entryPoint}),
15331551
.closestHitShader = VK_SHADER_UNUSED_KHR,
15341552
.anyHitShader = VK_SHADER_UNUSED_KHR,
15351553
.intersectionShader = VK_SHADER_UNUSED_KHR,
@@ -1543,9 +1561,9 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
15431561
.type = group.intersection.shader == nullptr ?
15441562
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR,
15451563
.generalShader = VK_SHADER_UNUSED_KHR,
1546-
.closestHitShader = getVkShaderIndex(group.closestHit.shader),
1547-
.anyHitShader = getVkShaderIndex(group.anyHit.shader),
1548-
.intersectionShader = getVkShaderIndex(group.intersection.shader),
1564+
.closestHitShader = getVkShaderIndex(group.closestHit),
1565+
.anyHitShader = getVkShaderIndex(group.anyHit),
1566+
.intersectionShader = getVkShaderIndex(group.intersection),
15491567
};
15501568
};
15511569

@@ -1554,9 +1572,10 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
15541572
auto processSpecInfo = [&](const IGPUPipelineBase::SShaderSpecInfo& spec, hlsl::ShaderStage shaderStage)
15551573
{
15561574
if (!spec.shader) return;
1557-
if (shaderIndexes.find(spec.shader) == shaderIndexes.end())
1575+
const auto key = VkShaderStageKey{ spec.shader, spec.entryPoint };
1576+
if (shaderIndexes.find(key) == shaderIndexes.end())
15581577
{
1559-
shaderIndexes.insert({ spec.shader, std::distance<decltype(outCreateInfo->pStages)>(outCreateInfo->pStages, outShaderStage)});
1578+
shaderIndexes.insert({ key , std::distance<decltype(outCreateInfo->pStages)>(outCreateInfo->pStages, outShaderStage)});
15601579
*(outShaderStage) = getVkShaderStageCreateInfoFrom(spec, shaderStage, false, outShaderModule, outEntryPoints, outRequiredSubgroupSize, outSpecInfo,outSpecMapEntry,outSpecData);
15611580
outShaderStage++;
15621581
}

0 commit comments

Comments
 (0)