Skip to content

Commit bbcff69

Browse files
author
kevyuu
committed
Implement get ray tracing stack size api
1 parent 266d0eb commit bbcff69

File tree

3 files changed

+148
-6
lines changed

3 files changed

+148
-6
lines changed

include/nbl/video/CVulkanRayTracingPipeline.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class CVulkanRayTracingPipeline final : public IGPURayTracingPipeline
1414
{
1515
using ShaderRef = core::smart_refctd_ptr<const CVulkanShader>;
1616
using ShaderContainer = core::smart_refctd_dynamic_array<ShaderRef>;
17+
using GeneralGroupStackSizeContainer = core::smart_refctd_dynamic_array<uint16_t>;
18+
using HitGroupStackSizeContainer = core::smart_refctd_dynamic_array<SHitGroupStackSize>;
1719

1820
public:
1921

@@ -33,12 +35,27 @@ class CVulkanRayTracingPipeline final : public IGPURayTracingPipeline
3335
virtual const SShaderGroupHandle& getHit(uint32_t index) const override;
3436
virtual const SShaderGroupHandle& getCallable(uint32_t index) const override;
3537

38+
virtual uint16_t getRaygenStackSize() const override;
39+
virtual std::span<const uint16_t> getMissStackSizes() const override;
40+
virtual std::span<const SHitGroupStackSize> getHitStackSizes() const override;
41+
virtual std::span<const uint16_t> getCallableStackSizes() const override;
42+
virtual uint16_t getDefaultStackSize() const override;
43+
3644
private:
3745
~CVulkanRayTracingPipeline() override;
3846

3947
const VkPipeline m_vkPipeline;
4048
ShaderContainer m_shaders;
4149
ShaderGroupHandleContainer m_shaderGroupHandles;
50+
uint16_t m_raygenStackSize;
51+
core::smart_refctd_dynamic_array<uint16_t> m_missStackSizes;
52+
core::smart_refctd_dynamic_array<SHitGroupStackSize> m_hitGroupStackSizes;
53+
core::smart_refctd_dynamic_array<uint16_t> m_callableStackSizes;
54+
55+
uint32_t getRaygenIndex() const;
56+
uint32_t getMissBaseIndex() const;
57+
uint32_t getHitBaseIndex() const;
58+
uint32_t getCallableBaseIndex() const;
4259
};
4360

4461
}

include/nbl/video/IGPURayTracingPipeline.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ class IGPURayTracingPipeline : public IBackendObject, public asset::IRayTracingP
2323
};
2424
static_assert(sizeof(SShaderGroupHandle) == video::SPhysicalDeviceLimits::ShaderGroupHandleSize);
2525

26+
// Stack sizes of the three optional shaders that make up one hit group.
// CVulkanRayTracingPipeline's constructor fills each field from
// vkGetRayTracingShaderGroupStackSizeKHR and stores 0 when the matching shader slot is Unused.
struct SHitGroupStackSize
27+
{
28+
// Stack size queried with VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR.
uint16_t closestHit;
29+
// Stack size queried with VK_SHADER_GROUP_SHADER_ANY_HIT_KHR.
uint16_t anyHit;
30+
// Stack size queried with VK_SHADER_GROUP_SHADER_INTERSECTION_KHR.
uint16_t intersection;
31+
};
32+
2633
struct SCreationParams final : pipeline_t::SCreationParams, SPipelineCreationParams<const IGPURayTracingPipeline>
2734
{
2835

@@ -68,6 +75,12 @@ class IGPURayTracingPipeline : public IBackendObject, public asset::IRayTracingP
6875
virtual const SShaderGroupHandle& getHit(uint32_t index) const = 0;
6976
virtual const SShaderGroupHandle& getCallable(uint32_t index) const = 0;
7077

78+
virtual uint16_t getRaygenStackSize() const = 0;
79+
virtual std::span<const uint16_t> getMissStackSizes() const = 0;
80+
virtual std::span<const SHitGroupStackSize> getHitStackSizes() const = 0;
81+
virtual std::span<const uint16_t> getCallableStackSizes() const = 0;
82+
virtual uint16_t getDefaultStackSize() const = 0;
83+
7184
protected:
7285
IGPURayTracingPipeline(const SCreationParams& params) : IBackendObject(core::smart_refctd_ptr<const ILogicalDevice>(params.layout->getOriginDevice())),
7386
pipeline_t(params),
Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
#include "nbl/video/CVulkanRayTracingPipeline.h"
1+
#include "nbl/asset/IRayTracingPipeline.h"
22

3+
#include "nbl/video/CVulkanRayTracingPipeline.h"
34
#include "nbl/video/CVulkanLogicalDevice.h"
5+
#include "nbl/video/IGPURayTracingPipeline.h"
6+
7+
#include <algorithm>
48

59
namespace nbl::video
610
{
@@ -12,11 +16,60 @@ namespace nbl::video
1216
IGPURayTracingPipeline(params),
1317
m_vkPipeline(vk_pipeline),
1418
m_shaders(core::make_refctd_dynamic_array<ShaderContainer>(params.shaders.size())),
19+
m_missStackSizes(core::make_refctd_dynamic_array<GeneralGroupStackSizeContainer>(params.shaderGroups.misses.size())),
20+
m_hitGroupStackSizes(core::make_refctd_dynamic_array<HitGroupStackSizeContainer>(params.shaderGroups.hits.size())),
21+
// Sized by the callable group count: the constructor body writes one entry per params.shaderGroups.callables element.
m_callableStackSizes(core::make_refctd_dynamic_array<GeneralGroupStackSizeContainer>(params.shaderGroups.callables.size())),
1522
m_shaderGroupHandles(std::move(shaderGroupHandles))
1623
{
1724
for (size_t shaderIx = 0; shaderIx < params.shaders.size(); shaderIx++)
1825
m_shaders->operator[](shaderIx) = ShaderRef(static_cast<const CVulkanShader*>(params.shaders[shaderIx].shader));
1926

27+
const auto* vulkanDevice = static_cast<const CVulkanLogicalDevice*>(getOriginDevice());
28+
auto* vk = vulkanDevice->getFunctionTable();
29+
30+
auto getVkShaderGroupStackSize = [&](uint32_t baseGroupIx, uint32_t shaderGroupIx, uint32_t shaderIx, VkShaderGroupShaderKHR shaderType) -> uint16_t
31+
{
32+
if (shaderIx == SShaderGroupsParams::SIndex::Unused)
33+
return 0;
34+
35+
return vk->vk.vkGetRayTracingShaderGroupStackSizeKHR(
36+
vulkanDevice->getInternalObject(),
37+
m_vkPipeline,
38+
baseGroupIx + shaderGroupIx,
39+
shaderType
40+
);
41+
};
42+
43+
m_raygenStackSize = getVkShaderGroupStackSize(getRaygenIndex(), 0, params.shaderGroups.raygen.index, VK_SHADER_GROUP_SHADER_GENERAL_KHR);
44+
45+
for (size_t shaderGroupIx = 0; shaderGroupIx < params.shaderGroups.misses.size(); shaderGroupIx++)
46+
{
47+
m_missStackSizes->operator[](shaderGroupIx) = getVkShaderGroupStackSize(
48+
getMissBaseIndex(),
49+
shaderGroupIx,
50+
params.shaderGroups.misses[shaderGroupIx].index,
51+
VK_SHADER_GROUP_SHADER_GENERAL_KHR);
52+
}
53+
54+
for (size_t shaderGroupIx = 0; shaderGroupIx < params.shaderGroups.hits.size(); shaderGroupIx++)
55+
{
56+
const auto& hitGroup = params.shaderGroups.hits[shaderGroupIx];
57+
const auto baseIndex = getHitBaseIndex();
58+
m_hitGroupStackSizes->operator[](shaderGroupIx) = SHitGroupStackSize{
59+
.closestHit = getVkShaderGroupStackSize(baseIndex,shaderGroupIx, hitGroup.closestHit, VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR),
60+
.anyHit = getVkShaderGroupStackSize(baseIndex, shaderGroupIx, hitGroup.anyHit,VK_SHADER_GROUP_SHADER_ANY_HIT_KHR),
61+
.intersection = getVkShaderGroupStackSize(baseIndex, shaderGroupIx, hitGroup.intersection, VK_SHADER_GROUP_SHADER_INTERSECTION_KHR),
62+
};
63+
}
64+
65+
for (size_t shaderGroupIx = 0; shaderGroupIx < params.shaderGroups.callables.size(); shaderGroupIx++)
66+
{
67+
m_callableStackSizes->operator[](shaderGroupIx) = getVkShaderGroupStackSize(
68+
getCallableBaseIndex(),
69+
shaderGroupIx,
70+
params.shaderGroups.callables[shaderGroupIx].index,
71+
VK_SHADER_GROUP_SHADER_GENERAL_KHR);
72+
}
2073
}
2174

2275
CVulkanRayTracingPipeline::~CVulkanRayTracingPipeline()
@@ -26,27 +79,86 @@ namespace nbl::video
2679
vk->vk.vkDestroyPipeline(vulkanDevice->getInternalObject(), m_vkPipeline, nullptr);
2780
}
2881

29-
3082
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getRaygen() const
3183
{
32-
return m_shaderGroupHandles->operator[](0);
84+
return m_shaderGroupHandles->operator[](getRaygenIndex());
3385
}
3486

3587
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getMiss(uint32_t index) const
3688
{
37-
const auto baseIndex = 1; // one raygen group before this groups
89+
const auto baseIndex = getMissBaseIndex();
3890
return m_shaderGroupHandles->operator[](baseIndex + index);
3991
}
4092

4193
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getHit(uint32_t index) const
4294
{
43-
const auto baseIndex = 1 + getMissGroupCount(); // one raygen group + miss gropus before this groups
95+
const auto baseIndex = getHitBaseIndex();
4496
return m_shaderGroupHandles->operator[](baseIndex + index);
4597
}
4698

4799
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getCallable(uint32_t index) const
48100
{
49-
const auto baseIndex = 1 + getMissGroupCount() + getHitGroupCount(); // one raygen group + miss groups + hit gropus before this groups
101+
const auto baseIndex = getCallableBaseIndex();
50102
return m_shaderGroupHandles->operator[](baseIndex + index);
51103
}
104+
105+
// Returns the raygen group's stack size that the constructor queried and cached
// (0 when the raygen shader index is Unused).
uint16_t CVulkanRayTracingPipeline::getRaygenStackSize() const
106+
{
107+
return m_raygenStackSize;
108+
}
109+
110+
// Returns a non-owning view over the whole m_missStackSizes array;
// the constructor writes entry i from miss group i of the creation params.
std::span<const uint16_t> CVulkanRayTracingPipeline::getMissStackSizes() const
111+
{
112+
return std::span(m_missStackSizes->begin(), m_missStackSizes->end());
113+
}
114+
115+
// Returns a non-owning view over the whole m_hitGroupStackSizes array;
// the constructor writes entry i (closest-hit / any-hit / intersection sizes) from hit group i of the creation params.
std::span<const IGPURayTracingPipeline::SHitGroupStackSize> CVulkanRayTracingPipeline::getHitStackSizes() const
116+
{
117+
return std::span(m_hitGroupStackSizes->begin(), m_hitGroupStackSizes->end());
118+
}
119+
120+
// Returns a non-owning view over the whole m_callableStackSizes array;
// the constructor writes entry i from callable group i of the creation params.
std::span<const uint16_t> CVulkanRayTracingPipeline::getCallableStackSizes() const
121+
{
122+
return std::span(m_callableStackSizes->begin(), m_callableStackSizes->end());
123+
}
124+
125+
uint16_t CVulkanRayTracingPipeline::getDefaultStackSize() const
126+
{
127+
// calculation follow the formula from
128+
// https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#ray-tracing-pipeline-stack
129+
const auto raygenStackMax = m_raygenStackSize;
130+
const auto closestHitStackMax = std::ranges::max_element(getHitStackSizes(), std::ranges::less{}, &SHitGroupStackSize::closestHit)->closestHit;
131+
const auto anyHitStackMax = std::ranges::max_element(getHitStackSizes(), std::ranges::less{}, &SHitGroupStackSize::anyHit)->anyHit;
132+
const auto intersectionStackMax = std::ranges::max_element(getHitStackSizes(), std::ranges::less{}, &SHitGroupStackSize::intersection)->intersection;
133+
const auto missStackMax = *std::ranges::max_element(getMissStackSizes());
134+
const auto callableStackMax = *std::ranges::max_element(getCallableStackSizes());
135+
return raygenStackMax + std::min<uint16_t>(1, m_params.maxRecursionDepth) *
136+
std::max(closestHitStackMax, std::max<uint16_t>(missStackMax, intersectionStackMax + anyHitStackMax)) +
137+
std::max<uint16_t>(0, m_params.maxRecursionDepth - 1) * std::max(closestHitStackMax, missStackMax) + 2 *
138+
callableStackMax;
139+
}
140+
141+
// Index of the raygen group; used both to index m_shaderGroupHandles and as the
// group index passed to the stack size query in the constructor.
uint32_t CVulkanRayTracingPipeline::getRaygenIndex() const
142+
{
143+
return 0;
144+
}
145+
146+
// Index of the first miss group; getMiss(i) reads m_shaderGroupHandles at this base plus i.
uint32_t CVulkanRayTracingPipeline::getMissBaseIndex() const
147+
{
148+
// the single raygen group comes before the miss groups
149+
return 1;
150+
}
151+
152+
// Index of the first hit group; getHit(i) reads m_shaderGroupHandles at this base plus i.
uint32_t CVulkanRayTracingPipeline::getHitBaseIndex() const
153+
{
154+
// the single raygen group and all miss groups come before the hit groups
155+
return 1 + getMissGroupCount();
156+
}
157+
158+
// Index of the first callable group; getCallable(i) reads m_shaderGroupHandles at this base plus i.
uint32_t CVulkanRayTracingPipeline::getCallableBaseIndex() const
159+
{
160+
// the single raygen group, all miss groups and all hit groups come before the callable groups
161+
return 1 + getMissGroupCount() + getHitGroupCount();
162+
}
163+
52164
}

0 commit comments

Comments
 (0)