Skip to content

Commit 061d49c

Browse files
author
kevyuu
committed
Fix maxShaderStages calculation when creating ray tracing pipeline
1 parent 3767ede commit 061d49c

File tree

1 file changed

+39
-20
lines changed

1 file changed

+39
-20
lines changed

src/nbl/video/CVulkanLogicalDevice.cpp

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,9 +1486,44 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
14861486

14871487
const VkPipelineCache vk_pipelineCache = pipelineCache ? static_cast<const CVulkanPipelineCache*>(pipelineCache)->getInternalObject():VK_NULL_HANDLE;
14881488

1489+
struct ShaderModuleKey
1490+
{
1491+
const asset::IShader* shader;
1492+
std::string_view entryPoint;
1493+
bool operator==(const ShaderModuleKey& other) const = default;
1494+
1495+
struct HashFunction
1496+
{
1497+
size_t operator()(const ShaderModuleKey& key) const
1498+
{
1499+
size_t rowHash = std::hash<const asset::IShader*>()(key.shader);
1500+
size_t colHash = std::hash<std::string_view>()(key.entryPoint) << 1;
1501+
return rowHash ^ colHash;
1502+
}
1503+
};
1504+
};
14891505
size_t maxShaderStages = 0;
14901506
for (const auto& info : createInfos)
1491-
maxShaderStages += info.shaderGroups.getShaderCount();
1507+
{
1508+
core::unordered_set<ShaderModuleKey, ShaderModuleKey::HashFunction> shaderModules;
1509+
shaderModules.insert({ info.shaderGroups.raygen.shader, info.shaderGroups.raygen.entryPoint });
1510+
for (const auto& miss : info.shaderGroups.misses)
1511+
{
1512+
shaderModules.insert({ miss.shader, miss.entryPoint });
1513+
}
1514+
for (const auto& hit : info.shaderGroups.hits)
1515+
{
1516+
shaderModules.insert({ hit.closestHit.shader, hit.closestHit.entryPoint });
1517+
shaderModules.insert({ hit.anyHit.shader, hit.anyHit.entryPoint });
1518+
shaderModules.insert({ hit.intersection.shader, hit.intersection.entryPoint });
1519+
}
1520+
for (const auto& callable : info.shaderGroups.callables)
1521+
{
1522+
shaderModules.insert({ callable.shader, callable.entryPoint });
1523+
}
1524+
1525+
maxShaderStages += shaderModules.size();
1526+
}
14921527
size_t maxShaderGroups = 0;
14931528
for (const auto& info : createInfos)
14941529
maxShaderGroups += info.shaderGroups.getShaderGroupCount();
@@ -1516,27 +1551,11 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
15161551

15171552
for (const auto& info : createInfos)
15181553
{
1519-
struct VkShaderStageKey
1520-
{
1521-
const asset::IShader* shader;
1522-
std::string_view entryPoint;
1523-
bool operator==(const VkShaderStageKey& other) const = default;
1524-
1525-
struct HashFunction
1526-
{
1527-
size_t operator()(const VkShaderStageKey& key) const
1528-
{
1529-
size_t rowHash = std::hash<const asset::IShader*>()(key.shader);
1530-
size_t colHash = std::hash<std::string_view>()(key.entryPoint) << 1;
1531-
return rowHash ^ colHash;
1532-
}
1533-
};
1534-
};
15351554

1536-
core::unordered_map<VkShaderStageKey, uint32_t, VkShaderStageKey::HashFunction> shaderIndexes;
1555+
core::unordered_map<ShaderModuleKey, uint32_t, ShaderModuleKey::HashFunction> shaderIndexes;
15371556
auto getVkShaderIndex = [&](const IGPUPipelineBase::SShaderSpecInfo& spec)
15381557
{
1539-
const auto key = VkShaderStageKey{ spec.shader, spec.entryPoint };
1558+
const auto key = ShaderModuleKey{ spec.shader, spec.entryPoint };
15401559
const auto index = key.shader == nullptr ? VK_SHADER_UNUSED_KHR : shaderIndexes[key];
15411560
return index;
15421561
};
@@ -1572,7 +1591,7 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
15721591
auto processSpecInfo = [&](const IGPUPipelineBase::SShaderSpecInfo& spec, hlsl::ShaderStage shaderStage)
15731592
{
15741593
if (!spec.shader) return;
1575-
const auto key = VkShaderStageKey{ spec.shader, spec.entryPoint };
1594+
const auto key = ShaderModuleKey{ spec.shader, spec.entryPoint };
15761595
if (shaderIndexes.find(key) == shaderIndexes.end())
15771596
{
15781597
shaderIndexes.insert({ key , std::distance<decltype(outCreateInfo->pStages)>(outCreateInfo->pStages, outShaderStage)});

0 commit comments

Comments
 (0)