Skip to content

Commit 3f0962c

Browse files
author
kevyuu
committed
Improve ShaderGroupParams naming
1 parent 5114b09 commit 3f0962c

File tree

2 files changed

+39
-40
lines changed

2 files changed

+39
-40
lines changed

include/nbl/asset/IRayTracingPipeline.h

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,32 @@ class IRayTracingPipelineBase : public virtual core::IReferenceCounted
1515
public:
1616
struct SShaderGroupsParams
1717
{
18-
constexpr static inline uint32_t ShaderUnused = 0xffFFffFFu;
19-
20-
struct SGeneralShaderGroup
18+
struct SIndex
2119
{
22-
uint32_t shaderIndex = ShaderUnused;
20+
constexpr static inline uint32_t Unused = 0xffFFffFFu;
21+
uint32_t index = Unused;
2322
};
2423

25-
struct SHitShaderGroup
24+
struct SHitGroup
2625
{
27-
uint32_t closestHitShaderIndex = ShaderUnused;
28-
uint32_t anyHitShaderIndex = ShaderUnused;
29-
uint32_t intersectionShaderIndex = ShaderUnused;
26+
uint32_t closestHit = SIndex::Unused;
27+
uint32_t anyHit = SIndex::Unused;
28+
uint32_t intersectionShader = SIndex::Unused;
3029
};
3130

32-
SGeneralShaderGroup raygenGroup;
33-
std::span<SGeneralShaderGroup> missGroups;
34-
std::span<SHitShaderGroup> hitGroups;
35-
std::span<SGeneralShaderGroup> callableGroups;
31+
SIndex raygen;
32+
std::span<SIndex> misses;
33+
std::span<SHitGroup> hits;
34+
std::span<SIndex> callables;
3635

3736
inline uint32_t getShaderGroupCount() const
3837
{
39-
return 1 + hitGroups.size() + missGroups.size() + callableGroups.size();
38+
return 1 + hits.size() + misses.size() + callables.size();
4039
}
4140

4241
};
43-
using SGeneralShaderGroup = SShaderGroupsParams::SGeneralShaderGroup;
44-
using SHitShaderGroup = SShaderGroupsParams::SHitShaderGroup;
42+
using SGeneralShaderGroup = SShaderGroupsParams::SIndex;
43+
using SHitShaderGroup = SShaderGroupsParams::SHitGroup;
4544

4645
struct SCachedCreationParams final
4746
{
@@ -86,14 +85,14 @@ class IRayTracingPipeline : public IPipeline<PipelineLayoutType>, public IRayTra
8685
return shaders[index].shader->getStage();
8786
};
8887

89-
if (shaderGroups.raygenGroup.shaderIndex >= shaders.size())
88+
if (shaderGroups.raygen.index >= shaders.size())
9089
return false;
91-
if (getShaderStage(shaderGroups.raygenGroup.shaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_RAYGEN)
90+
if (getShaderStage(shaderGroups.raygen.index) != ICPUShader::E_SHADER_STAGE::ESS_RAYGEN)
9291
return false;
9392

9493
auto isValidShaderIndex = [this, getShaderStage](size_t index, ICPUShader::E_SHADER_STAGE expectedStage) -> bool
9594
{
96-
if (index == SShaderGroupsParams::ShaderUnused)
95+
if (index == SShaderGroupsParams::SIndex::Unused)
9796
return true;
9897
if (index >= shaders.size())
9998
return false;
@@ -102,27 +101,27 @@ class IRayTracingPipeline : public IPipeline<PipelineLayoutType>, public IRayTra
102101
return true;
103102
};
104103

105-
for (const auto& shaderGroup : shaderGroups.hitGroups)
104+
for (const auto& shaderGroup : shaderGroups.hits)
106105
{
107-
if (!isValidShaderIndex(shaderGroup.anyHitShaderIndex, ICPUShader::E_SHADER_STAGE::ESS_ANY_HIT))
106+
if (!isValidShaderIndex(shaderGroup.anyHit, ICPUShader::E_SHADER_STAGE::ESS_ANY_HIT))
108107
return false;
109108

110-
if (!isValidShaderIndex(shaderGroup.closestHitShaderIndex, ICPUShader::E_SHADER_STAGE::ESS_CLOSEST_HIT))
109+
if (!isValidShaderIndex(shaderGroup.closestHit, ICPUShader::E_SHADER_STAGE::ESS_CLOSEST_HIT))
111110
return false;
112111

113-
if (!isValidShaderIndex(shaderGroup.intersectionShaderIndex, ICPUShader::E_SHADER_STAGE::ESS_INTERSECTION))
112+
if (!isValidShaderIndex(shaderGroup.intersectionShader, ICPUShader::E_SHADER_STAGE::ESS_INTERSECTION))
114113
return false;
115114
}
116115

117-
for (const auto& shaderGroup : shaderGroups.missGroups)
116+
for (const auto& shaderGroup : shaderGroups.misses)
118117
{
119-
if (!isValidShaderIndex(shaderGroup.shaderIndex, ICPUShader::E_SHADER_STAGE::ESS_MISS))
118+
if (!isValidShaderIndex(shaderGroup.index, ICPUShader::E_SHADER_STAGE::ESS_MISS))
120119
return false;
121120
}
122121

123-
for (const auto& shaderGroup : shaderGroups.callableGroups)
122+
for (const auto& shaderGroup : shaderGroups.callables)
124123
{
125-
if (!isValidShaderIndex(shaderGroup.shaderIndex, ICPUShader::E_SHADER_STAGE::ESS_CALLABLE))
124+
if (!isValidShaderIndex(shaderGroup.index, ICPUShader::E_SHADER_STAGE::ESS_CALLABLE))
126125
return false;
127126
}
128127
return true;
@@ -153,10 +152,10 @@ class IRayTracingPipeline : public IPipeline<PipelineLayoutType>, public IRayTra
153152
explicit IRayTracingPipeline(const SCreationParams& _params) :
154153
IPipeline<PipelineLayoutType>(core::smart_refctd_ptr<const PipelineLayoutType>(_params.layout)),
155154
m_params(_params.cached),
156-
m_raygenShaderGroup(_params.shaderGroups.raygenGroup),
157-
m_missShaderGroups(core::make_refctd_dynamic_array<SGeneralShaderGroupContainer>(_params.shaderGroups.missGroups)),
158-
m_hitShaderGroups(core::make_refctd_dynamic_array<SHitShaderGroupContainer>(_params.shaderGroups.hitGroups)),
159-
m_callableShaderGroups(core::make_refctd_dynamic_array<SGeneralShaderGroupContainer>(_params.shaderGroups.callableGroups))
155+
m_raygenShaderGroup(_params.shaderGroups.raygen),
156+
m_missShaderGroups(core::make_refctd_dynamic_array<SGeneralShaderGroupContainer>(_params.shaderGroups.misses)),
157+
m_hitShaderGroups(core::make_refctd_dynamic_array<SHitShaderGroupContainer>(_params.shaderGroups.hits)),
158+
m_callableShaderGroups(core::make_refctd_dynamic_array<SGeneralShaderGroupContainer>(_params.shaderGroups.callables))
160159
{}
161160

162161
SCachedCreationParams m_params;

src/nbl/video/CVulkanLogicalDevice.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,7 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
14651465
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
14661466
.pNext = nullptr,
14671467
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
1468-
.generalShader = getVkShaderIndex(group.shaderIndex),
1468+
.generalShader = getVkShaderIndex(group.index),
14691469
.closestHitShader = VK_SHADER_UNUSED_KHR,
14701470
.anyHitShader = VK_SHADER_UNUSED_KHR,
14711471
.intersectionShader = VK_SHADER_UNUSED_KHR,
@@ -1476,12 +1476,12 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
14761476
return {
14771477
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
14781478
.pNext = nullptr,
1479-
.type = group.intersectionShaderIndex == SShaderGroupParams::ShaderUnused ?
1479+
.type = group.intersectionShader == SShaderGroupParams::SIndex::Unused ?
14801480
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR,
14811481
.generalShader = VK_SHADER_UNUSED_KHR,
1482-
.closestHitShader = getVkShaderIndex(group.closestHitShaderIndex),
1483-
.anyHitShader = getVkShaderIndex(group.anyHitShaderIndex),
1484-
.intersectionShader = getVkShaderIndex(group.intersectionShaderIndex),
1482+
.closestHitShader = getVkShaderIndex(group.closestHit),
1483+
.anyHitShader = getVkShaderIndex(group.anyHit),
1484+
.intersectionShader = getVkShaderIndex(group.intersectionShader),
14851485
};
14861486
};
14871487
for (const auto& info : createInfos)
@@ -1499,14 +1499,14 @@ void CVulkanLogicalDevice::createRayTracingPipelines_impl(
14991499

15001500
const auto& shaderGroups = info.shaderGroups;
15011501
outCreateInfo->pGroups = outShaderGroup;
1502-
*(outShaderGroup++) = getGeneralVkRayTracingShaderGroupCreateInfo(shaderGroups.raygenGroup);
1503-
for (const auto& shaderGroup : shaderGroups.missGroups)
1502+
*(outShaderGroup++) = getGeneralVkRayTracingShaderGroupCreateInfo(shaderGroups.raygen);
1503+
for (const auto& shaderGroup : shaderGroups.misses)
15041504
*(outShaderGroup++) = getGeneralVkRayTracingShaderGroupCreateInfo(shaderGroup);
1505-
for (const auto& shaderGroup : shaderGroups.hitGroups)
1505+
for (const auto& shaderGroup : shaderGroups.hits)
15061506
*(outShaderGroup++) = getHitVkRayTracingShaderGroupCreateInfo(shaderGroup);
1507-
for (const auto& shaderGroup : shaderGroups.callableGroups)
1507+
for (const auto& shaderGroup : shaderGroups.callables)
15081508
*(outShaderGroup++) = getGeneralVkRayTracingShaderGroupCreateInfo(shaderGroup);
1509-
outCreateInfo->groupCount = 1 + shaderGroups.hitGroups.size() + shaderGroups.missGroups.size() + shaderGroups.callableGroups.size();
1509+
outCreateInfo->groupCount = 1 + shaderGroups.hits.size() + shaderGroups.misses.size() + shaderGroups.callables.size();
15101510
outCreateInfo->maxPipelineRayRecursionDepth = info.cached.maxRecursionDepth;
15111511
}
15121512

0 commit comments

Comments
 (0)