
Commit 34ebc4c

First few migrations of the Scan algorithm modules
1 parent e03b142 commit 34ebc4c

File tree: 3 files changed, +253 -1 lines changed

3 files changed

+253
-1
lines changed
Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
#ifndef _NBL_GLSL_SCAN_DEFAULT_SCHEDULER_INCLUDED_
#define _NBL_GLSL_SCAN_DEFAULT_SCHEDULER_INCLUDED_

#include <nbl/builtin/hlsl/scan/parameters_struct.hlsl>

#ifdef __cplusplus
#define uint uint32_t
#endif

const uint gl_LocalInvocationIndex: SV_GroupIndex;

namespace nbl
{
namespace hlsl
{
namespace scan
{
struct DefaultSchedulerParameters_t
{
	uint finishedFlagOffset[NBL_BUILTIN_MAX_SCAN_LEVELS-1];
	uint cumulativeWorkgroupCount[NBL_BUILTIN_MAX_SCAN_LEVELS];
};
}
}
}

#ifdef __cplusplus
#undef uint
#else

namespace nbl
{
namespace hlsl
{
namespace scan
{
namespace scheduler
{
/**
 * The CScanner.h parameter computation calculates the number of virtual workgroups that will have to be launched for the Scan operation
 * (always based on the elementCount) as well as different offsets for the results of each step of the Scan operation, flag positions
 * that are used for synchronization etc.
 * Remember that CScanner does a Blelloch Scan, which works in levels. In each level of the Blelloch scan the array of elements is
 * broken down into sets of size=WorkgroupSize and each set is scanned using Hillis & Steele (aka the Kogge-Stone adder). The result of
 * the scan is provided as an array element for the next level of the Blelloch Scan. This means that if we have 10000 elements and
 * WorkgroupSize=250, we will break the array into 40 sets and take their reduction results. The next level of the Blelloch Scan will
 * have an array of size 40. Only a single workgroup will be needed to work on that. After that array is scanned, we use the results
 * in the downsweep phase of the Blelloch Scan.
 * Keep in mind that each virtual workgroup executes a single step of the whole algorithm, which is why we have the cumulativeWorkgroupCount.
 * The first virtual workgroups will work on the upsweep phase, the next on the downsweep phase.
 * The intermediate results are stored in a scratch buffer. That buffer's size is the sum of the element-array sizes of all the
 * Blelloch levels. Using the previous example, the scratch size should be 10000 + 40.
 *
 * Parameter meaning:
 * |> lastElement - the index of the last element of each Blelloch level in the scratch buffer
 * |> topLevel - the top level the Blelloch Scan will have (this depends on the elementCount and the WorkgroupSize)
 * |> temporaryStorageOffset - an offset array for each level of the Blelloch Scan. It is used when storing the REDUCTION result of each workgroup scan
 * |> cumulativeWorkgroupCount - the sum-scan of all the workgroups that will need to be launched for each level of the Blelloch Scan (both upsweep and downsweep)
 * |> finishedFlagOffset - an index in the scratch buffer where each virtual workgroup indicates that ALL its invocations have finished their work. This helps
 *    with synchronizing between workgroups via while-loop spinning.
 */
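// Worked example (illustrative only; it assumes _NBL_HLSL_WORKGROUP_SIZE_ = 256, i.e. _NBL_HLSL_WORKGROUP_SIZE_LOG2_ = 8,
// which is an assumption of this note and not something this header sets): for elementCount = 10000 the function below yields
//   lastElement              = { 9999, 39, 0, 0 }  // 40 workgroups needed on level 0, 1 on level 1
//   topLevel                 = 1                   // firstbithigh(9999) = 13, 13/8 = 1
//   cumulativeWorkgroupCount = { 40, 41, 81 }      // 40 upsweep + 1 top level + 40 downsweep virtual workgroups
//   finishedFlagOffset       = { 0, 1 }
//   temporaryStorageOffset   = { 2 }
// Entries beyond those shown stay unwritten for topLevel == 1; the total number of virtual workgroups to launch is the
// last written cumulative entry, cumulativeWorkgroupCount[topLevel*2] = 81.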
void computeParameters(in uint elementCount, out Parameters_t _scanParams, out DefaultSchedulerParameters_t _schedulerParams)
{
#define WorkgroupCount(Level) (_scanParams.lastElement[Level+1]+1u)
	_scanParams.lastElement[0] = elementCount-1u;
	_scanParams.topLevel = firstbithigh(_scanParams.lastElement[0])/_NBL_HLSL_WORKGROUP_SIZE_LOG2_;
	// REVIEW: _NBL_HLSL_WORKGROUP_SIZE_LOG2_ is defined in files that include THIS file. Why not query the API for workgroup size at runtime?

	for (uint i=0; i<NBL_BUILTIN_MAX_SCAN_LEVELS/2;)
	{
		const uint next = i+1;
		_scanParams.lastElement[next] = _scanParams.lastElement[i]>>_NBL_HLSL_WORKGROUP_SIZE_LOG2_;
		i = next;
	}
	_schedulerParams.cumulativeWorkgroupCount[0] = WorkgroupCount(0);
	_schedulerParams.finishedFlagOffset[0] = 0u;
	switch(_scanParams.topLevel)
	{
		case 1u:
			_schedulerParams.cumulativeWorkgroupCount[1] = _schedulerParams.cumulativeWorkgroupCount[0]+1u;
			_schedulerParams.cumulativeWorkgroupCount[2] = _schedulerParams.cumulativeWorkgroupCount[1]+WorkgroupCount(0);
			// climb up
			_schedulerParams.finishedFlagOffset[1] = 1u;

			_scanParams.temporaryStorageOffset[0] = 2u;
			break;
		case 2u:
			_schedulerParams.cumulativeWorkgroupCount[1] = _schedulerParams.cumulativeWorkgroupCount[0]+WorkgroupCount(1);
			_schedulerParams.cumulativeWorkgroupCount[2] = _schedulerParams.cumulativeWorkgroupCount[1]+1u;
			_schedulerParams.cumulativeWorkgroupCount[3] = _schedulerParams.cumulativeWorkgroupCount[2]+WorkgroupCount(1);
			_schedulerParams.cumulativeWorkgroupCount[4] = _schedulerParams.cumulativeWorkgroupCount[3]+WorkgroupCount(0);
			// climb up
			_schedulerParams.finishedFlagOffset[1] = WorkgroupCount(1);
			_schedulerParams.finishedFlagOffset[2] = _schedulerParams.finishedFlagOffset[1]+1u;
			// climb down
			_schedulerParams.finishedFlagOffset[3] = _schedulerParams.finishedFlagOffset[1]+2u;

			_scanParams.temporaryStorageOffset[0] = _schedulerParams.finishedFlagOffset[3]+WorkgroupCount(1);
			_scanParams.temporaryStorageOffset[1] = _scanParams.temporaryStorageOffset[0]+WorkgroupCount(0);
			break;
		case 3u:
			_schedulerParams.cumulativeWorkgroupCount[1] = _schedulerParams.cumulativeWorkgroupCount[0]+WorkgroupCount(1);
			_schedulerParams.cumulativeWorkgroupCount[2] = _schedulerParams.cumulativeWorkgroupCount[1]+WorkgroupCount(2);
			_schedulerParams.cumulativeWorkgroupCount[3] = _schedulerParams.cumulativeWorkgroupCount[2]+1u;
			_schedulerParams.cumulativeWorkgroupCount[4] = _schedulerParams.cumulativeWorkgroupCount[3]+WorkgroupCount(2);
			_schedulerParams.cumulativeWorkgroupCount[5] = _schedulerParams.cumulativeWorkgroupCount[4]+WorkgroupCount(1);
			_schedulerParams.cumulativeWorkgroupCount[6] = _schedulerParams.cumulativeWorkgroupCount[5]+WorkgroupCount(0);
			// climb up
			_schedulerParams.finishedFlagOffset[1] = WorkgroupCount(1);
			_schedulerParams.finishedFlagOffset[2] = _schedulerParams.finishedFlagOffset[1]+WorkgroupCount(2);
			_schedulerParams.finishedFlagOffset[3] = _schedulerParams.finishedFlagOffset[2]+1u;
			// climb down
			_schedulerParams.finishedFlagOffset[4] = _schedulerParams.finishedFlagOffset[2]+2u;
			_schedulerParams.finishedFlagOffset[5] = _schedulerParams.finishedFlagOffset[4]+WorkgroupCount(2);

			_scanParams.temporaryStorageOffset[0] = _schedulerParams.finishedFlagOffset[5]+WorkgroupCount(1);
			_scanParams.temporaryStorageOffset[1] = _scanParams.temporaryStorageOffset[0]+WorkgroupCount(0);
			_scanParams.temporaryStorageOffset[2] = _scanParams.temporaryStorageOffset[1]+WorkgroupCount(1);
			break;
		default:
			break;
#if NBL_BUILTIN_MAX_SCAN_LEVELS>7
#error "Switch needs more cases"
#endif
	}
#undef WorkgroupCount
}

/**
 * treeLevel - the current level in the Blelloch Scan
 * localWorkgroupIndex - the workgroup index the current invocation is a part of in the specific virtual dispatch.
 * For example, if we have dispatched 10 workgroups and the virtual workgroup number is 35, then the localWorkgroupIndex should be 5.
 */
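// Illustrative mapping (using the assumed worked example above, where cumulativeWorkgroupCount = { 40, 41, 81 }):
//   globalWorkgroupIndex  0..39 -> treeLevel 0 (upsweep),   localWorkgroupIndex = globalWorkgroupIndex
//   globalWorkgroupIndex     40 -> treeLevel 1 (top),       localWorkgroupIndex = 0
//   globalWorkgroupIndex 41..80 -> treeLevel 2 (downsweep), localWorkgroupIndex = globalWorkgroupIndex - 41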
template<class ScratchAccessor>
bool getWork(in DefaultSchedulerParameters_t params, in uint topLevel, out uint treeLevel, out uint localWorkgroupIndex)
{
	ScratchAccessor sharedScratch;
	if(gl_LocalInvocationIndex == 0u)
	{
		uint64_t original;
		InterlockedAdd(scanScratch.workgroupsStarted, 1u, original); // TODO (PentaKon): Refactor this when the ScanScratch descriptor set is declared
		sharedScratch.set(gl_LocalInvocationIndex, original);
	}
	else if (gl_LocalInvocationIndex == 1u)
	{
		sharedScratch.set(gl_LocalInvocationIndex, 0u);
	}
	GroupMemoryBarrierWithGroupSync(); // REVIEW: refactor this somewhere with GLSL terminology?

	uint globalWorkgroupIndex; // does every thread need to know?
	sharedScratch.get(0u, globalWorkgroupIndex);
	const uint lastLevel = topLevel<<1u;
	if (gl_LocalInvocationIndex<=lastLevel && globalWorkgroupIndex>=params.cumulativeWorkgroupCount[gl_LocalInvocationIndex])
	{
		InterlockedAdd(sharedScratch.get(1u, ?), 1u); // REVIEW: The way scratchaccessoradaptor is implemented (e.g. under subgroup/arithmetic_portability) doesn't allow for atomic ops on the scratch buffer. Should we ask for another implementation that overrides the [] operator ?
	}
	GroupMemoryBarrierWithGroupSync(); // TODO (PentaKon): Possibly refactor?

	sharedScratch.get(1u, treeLevel);
	if(treeLevel>lastLevel)
		return true;

	localWorkgroupIndex = globalWorkgroupIndex;
	const bool dependentLevel = treeLevel != 0u;
	if(dependentLevel)
	{
		const uint prevLevel = treeLevel - 1u;
		localWorkgroupIndex -= params.cumulativeWorkgroupCount[prevLevel];
		if(gl_LocalInvocationIndex == 0u)
		{
			uint dependentsCount = 1u;
			if(treeLevel <= topLevel)
			{
				dependentsCount = _NBL_HLSL_WORKGROUP_SIZE_; // REVIEW: Defined in the files that include this file?
				const bool lastWorkgroup = (globalWorkgroupIndex+1u)==params.cumulativeWorkgroupCount[treeLevel];
				if (lastWorkgroup)
				{
					const Parameters_t scanParams = getParameters(); // TODO (PentaKon): Undeclared as of now, this should return the Parameters_t from the push constants of (in)direct shader
					dependentsCount = scanParams.lastElement[treeLevel]+1u;
					if (treeLevel<topLevel)
					{
						dependentsCount -= scanParams.lastElement[treeLevel+1u]*_NBL_HLSL_WORKGROUP_SIZE_;
					}
				}
			}
			uint dependentsFinishedFlagOffset = localWorkgroupIndex;
			if (treeLevel>topLevel) // !(prevLevel<topLevel) TODO: merge with `else` above?
				dependentsFinishedFlagOffset /= _NBL_HLSL_WORKGROUP_SIZE_;
			dependentsFinishedFlagOffset += params.finishedFlagOffset[prevLevel];
			while (scanScratch.data[dependentsFinishedFlagOffset]!=dependentsCount) // TODO (PentaKon): Refactor this when the ScanScratch descriptor set is declared
				GroupMemoryBarrierWithGroupSync();
		}
	}
	GroupMemoryBarrierWithGroupSync();
	return false;
}

void markComplete(in DefaultSchedulerParameters_t params, in uint topLevel, in uint treeLevel, in uint localWorkgroupIndex)
{
	GroupMemoryBarrierWithGroupSync(); // must complete writing the data before flagging itself as complete
	if (gl_LocalInvocationIndex==0u)
	{
		uint finishedFlagOffset = params.finishedFlagOffset[treeLevel];
		if (treeLevel<topLevel)
		{
			finishedFlagOffset += localWorkgroupIndex/_NBL_HLSL_WORKGROUP_SIZE_;
			InterlockedAdd(scanScratch.data[finishedFlagOffset],1u);
		}
		else if (treeLevel!=(topLevel<<1u))
		{
			finishedFlagOffset += localWorkgroupIndex;
			scanScratch.data[finishedFlagOffset] = 1u; // TODO (PentaKon): Refactor this when the ScanScratch descriptor set is declared
		}
	}
}
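
// Minimal sketch of how getWork()/markComplete() are expected to pair up in the (in)direct scan shader's
// main loop. The function name, the ScratchAccessor type and the per-level scan body are assumptions for
// illustration only; none of them are declared anywhere in this commit.
template<class ScratchAccessor>
void virtualWorkgroupLoop(in Parameters_t scanParams, in DefaultSchedulerParameters_t schedulerParams)
{
	uint treeLevel, localWorkgroupIndex;
	// keep grabbing virtual workgroups until getWork() reports that every level has been handed out
	while (!getWork<ScratchAccessor>(schedulerParams, scanParams.topLevel, treeLevel, localWorkgroupIndex))
	{
		// ...perform the upsweep/downsweep step for (treeLevel, localWorkgroupIndex) here...
		markComplete(schedulerParams, scanParams.topLevel, treeLevel, localWorkgroupIndex);
	}
}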
}
}
}
}
#endif

#endif
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
#ifndef _NBL_HLSL_SCAN_PARAMETERS_STRUCT_INCLUDED_
#define _NBL_HLSL_SCAN_PARAMETERS_STRUCT_INCLUDED_

#define NBL_BUILTIN_MAX_SCAN_LEVELS 7

#ifdef __cplusplus
#define uint uint32_t
#endif

namespace nbl
{
namespace hlsl
{
namespace scan
{
struct Parameters_t {
	uint topLevel;
	uint lastElement[NBL_BUILTIN_MAX_SCAN_LEVELS/2+1];
	uint temporaryStorageOffset[NBL_BUILTIN_MAX_SCAN_LEVELS/2];
};
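// With NBL_BUILTIN_MAX_SCAN_LEVELS = 7 the struct holds 4 lastElement entries (Blelloch levels 0..3) and
// 3 per-level reduction offsets, matching the topLevel <= 3 cases handled by the default scheduler's switch.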
}
}
}

#ifdef __cplusplus
#undef uint
#endif

#endif

include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ struct reduction<binops::bitwise_max>
 
 namespace portability
 {
-
+// REVIEW: This seems like generic code, unrelated to subgroups. Should we move it to different module?
 template<class NumberScratchAccessor>
 struct ScratchAccessorAdaptor {
 	NumberScratchAccessor accessor;
