1
- #include " nbl/video/CVulkanRayTracingPipeline .h"
1
+ #include " nbl/asset/IRayTracingPipeline .h"
2
2
3
+ #include " nbl/video/CVulkanRayTracingPipeline.h"
3
4
#include " nbl/video/CVulkanLogicalDevice.h"
5
+ #include " nbl/video/IGPURayTracingPipeline.h"
6
+
7
+ #include < algorithm>
4
8
5
9
namespace nbl ::video
6
10
{
@@ -12,11 +16,60 @@ namespace nbl::video
12
16
IGPURayTracingPipeline (params),
13
17
m_vkPipeline (vk_pipeline),
14
18
m_shaders (core::make_refctd_dynamic_array<ShaderContainer>(params.shaders.size())),
19
+ m_missStackSizes (core::make_refctd_dynamic_array<GeneralGroupStackSizeContainer>(params.shaderGroups.misses.size())),
20
+ m_hitGroupStackSizes (core::make_refctd_dynamic_array<HitGroupStackSizeContainer>(params.shaderGroups.hits.size())),
21
+ m_callableStackSizes (core::make_refctd_dynamic_array<GeneralGroupStackSizeContainer>(params.shaderGroups.hits.size())),
15
22
m_shaderGroupHandles (std::move(shaderGroupHandles))
16
23
{
17
24
for (size_t shaderIx = 0 ; shaderIx < params.shaders .size (); shaderIx++)
18
25
m_shaders->operator [](shaderIx) = ShaderRef (static_cast <const CVulkanShader*>(params.shaders [shaderIx].shader ));
19
26
27
+ const auto * vulkanDevice = static_cast <const CVulkanLogicalDevice*>(getOriginDevice ());
28
+ auto * vk = vulkanDevice->getFunctionTable ();
29
+
30
+ auto getVkShaderGroupStackSize = [&](uint32_t baseGroupIx, uint32_t shaderGroupIx, uint32_t shaderIx, VkShaderGroupShaderKHR shaderType) -> uint16_t
31
+ {
32
+ if (shaderIx == SShaderGroupsParams::SIndex::Unused)
33
+ return 0 ;
34
+
35
+ return vk->vk .vkGetRayTracingShaderGroupStackSizeKHR (
36
+ vulkanDevice->getInternalObject (),
37
+ m_vkPipeline,
38
+ baseGroupIx + shaderGroupIx,
39
+ shaderType
40
+ );
41
+ };
42
+
43
+ m_raygenStackSize = getVkShaderGroupStackSize (getRaygenIndex (), 0 , params.shaderGroups .raygen .index , VK_SHADER_GROUP_SHADER_GENERAL_KHR);
44
+
45
+ for (size_t shaderGroupIx = 0 ; shaderGroupIx < params.shaderGroups .misses .size (); shaderGroupIx++)
46
+ {
47
+ m_missStackSizes->operator [](shaderGroupIx) = getVkShaderGroupStackSize (
48
+ getMissBaseIndex (),
49
+ shaderGroupIx,
50
+ params.shaderGroups .misses [shaderGroupIx].index ,
51
+ VK_SHADER_GROUP_SHADER_GENERAL_KHR);
52
+ }
53
+
54
+ for (size_t shaderGroupIx = 0 ; shaderGroupIx < params.shaderGroups .hits .size (); shaderGroupIx++)
55
+ {
56
+ const auto & hitGroup = params.shaderGroups .hits [shaderGroupIx];
57
+ const auto baseIndex = getHitBaseIndex ();
58
+ m_hitGroupStackSizes->operator [](shaderGroupIx) = SHitGroupStackSize{
59
+ .closestHit = getVkShaderGroupStackSize (baseIndex,shaderGroupIx, hitGroup.closestHit , VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR),
60
+ .anyHit = getVkShaderGroupStackSize (baseIndex, shaderGroupIx, hitGroup.anyHit ,VK_SHADER_GROUP_SHADER_ANY_HIT_KHR),
61
+ .intersection = getVkShaderGroupStackSize (baseIndex, shaderGroupIx, hitGroup.intersection , VK_SHADER_GROUP_SHADER_INTERSECTION_KHR),
62
+ };
63
+ }
64
+
65
+ for (size_t shaderGroupIx = 0 ; shaderGroupIx < params.shaderGroups .callables .size (); shaderGroupIx++)
66
+ {
67
+ m_callableStackSizes->operator [](shaderGroupIx) = getVkShaderGroupStackSize (
68
+ getCallableBaseIndex (),
69
+ shaderGroupIx,
70
+ params.shaderGroups .callables [shaderGroupIx].index ,
71
+ VK_SHADER_GROUP_SHADER_GENERAL_KHR);
72
+ }
20
73
}
21
74
22
75
CVulkanRayTracingPipeline::~CVulkanRayTracingPipeline ()
@@ -26,27 +79,86 @@ namespace nbl::video
26
79
vk->vk .vkDestroyPipeline (vulkanDevice->getInternalObject (), m_vkPipeline, nullptr );
27
80
}
28
81
29
-
30
82
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getRaygen () const
31
83
{
32
- return m_shaderGroupHandles->operator [](0 );
84
+ return m_shaderGroupHandles->operator [](getRaygenIndex () );
33
85
}
34
86
35
87
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getMiss (uint32_t index) const
36
88
{
37
- const auto baseIndex = 1 ; // one raygen group before this groups
89
+ const auto baseIndex = getMissBaseIndex ();
38
90
return m_shaderGroupHandles->operator [](baseIndex + index);
39
91
}
40
92
41
93
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getHit (uint32_t index) const
42
94
{
43
- const auto baseIndex = 1 + getMissGroupCount (); // one raygen group + miss gropus before this groups
95
+ const auto baseIndex = getHitBaseIndex ();
44
96
return m_shaderGroupHandles->operator [](baseIndex + index);
45
97
}
46
98
47
99
const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getCallable (uint32_t index) const
48
100
{
49
- const auto baseIndex = 1 + getMissGroupCount () + getHitGroupCount (); // one raygen group + miss groups + hit gropus before this groups
101
+ const auto baseIndex = getCallableBaseIndex ();
50
102
return m_shaderGroupHandles->operator [](baseIndex + index);
51
103
}
104
+
105
+ uint16_t CVulkanRayTracingPipeline::getRaygenStackSize () const
106
+ {
107
+ return m_raygenStackSize;
108
+ }
109
+
110
+ std::span<const uint16_t > CVulkanRayTracingPipeline::getMissStackSizes () const
111
+ {
112
+ return std::span (m_missStackSizes->begin (), m_missStackSizes->end ());
113
+ }
114
+
115
+ std::span<const IGPURayTracingPipeline::SHitGroupStackSize> CVulkanRayTracingPipeline::getHitStackSizes () const
116
+ {
117
+ return std::span (m_hitGroupStackSizes->begin (), m_hitGroupStackSizes->end ());
118
+ }
119
+
120
+ std::span<const uint16_t > CVulkanRayTracingPipeline::getCallableStackSizes () const
121
+ {
122
+ return std::span (m_callableStackSizes->begin (), m_callableStackSizes->end ());
123
+ }
124
+
125
+ uint16_t CVulkanRayTracingPipeline::getDefaultStackSize () const
126
+ {
127
+ // calculation follow the formula from
128
+ // https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#ray-tracing-pipeline-stack
129
+ const auto raygenStackMax = m_raygenStackSize;
130
+ const auto closestHitStackMax = std::ranges::max_element (getHitStackSizes (), std::ranges::less{}, &SHitGroupStackSize::closestHit)->closestHit ;
131
+ const auto anyHitStackMax = std::ranges::max_element (getHitStackSizes (), std::ranges::less{}, &SHitGroupStackSize::anyHit)->anyHit ;
132
+ const auto intersectionStackMax = std::ranges::max_element (getHitStackSizes (), std::ranges::less{}, &SHitGroupStackSize::intersection)->intersection ;
133
+ const auto missStackMax = *std::ranges::max_element (getMissStackSizes ());
134
+ const auto callableStackMax = *std::ranges::max_element (getCallableStackSizes ());
135
+ return raygenStackMax + std::min<uint16_t >(1 , m_params.maxRecursionDepth ) *
136
+ std::max (closestHitStackMax, std::max<uint16_t >(missStackMax, intersectionStackMax + anyHitStackMax)) +
137
+ std::max<uint16_t >(0 , m_params.maxRecursionDepth - 1 ) * std::max (closestHitStackMax, missStackMax) + 2 *
138
+ callableStackMax;
139
+ }
140
+
141
+ uint32_t CVulkanRayTracingPipeline::getRaygenIndex () const
142
+ {
143
+ return 0 ;
144
+ }
145
+
146
+ uint32_t CVulkanRayTracingPipeline::getMissBaseIndex () const
147
+ {
148
+ // one raygen group before this groups
149
+ return 1 ;
150
+ }
151
+
152
+ uint32_t CVulkanRayTracingPipeline::getHitBaseIndex () const
153
+ {
154
+ // one raygen group + miss groups before this groups
155
+ return 1 + getMissGroupCount ();
156
+ }
157
+
158
+ uint32_t CVulkanRayTracingPipeline::getCallableBaseIndex () const
159
+ {
160
+ // one raygen group + miss groups + hit groups before this groups
161
+ return 1 + getMissGroupCount () + getHitGroupCount ();
162
+ }
163
+
52
164
}
0 commit comments