@@ -84,12 +84,15 @@ struct reduce_level0
84
84
using scalar_t = typename BinOp::type_t;
85
85
using vector_t = vector <scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
86
86
87
- template<class DataAccessor, class ScratchAccessor, class Params >
87
+ template<class DataAccessor, class ScratchAccessor>
88
88
static void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
89
89
{
90
+ using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
91
+ using params_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
92
+
90
93
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex ();
91
94
// level 0 reduction
92
- subgroup2::reduction<Params > reduction0;
95
+ subgroup2::reduction<params_t > reduction0;
93
96
[unroll]
94
97
for (uint16_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
95
98
{
@@ -112,11 +115,14 @@ struct scan_level0
112
115
using scalar_t = typename BinOp::type_t;
113
116
using vector_t = vector <scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
114
117
115
- template<class DataAccessor, class ScratchAccessor, class Params >
118
+ template<class DataAccessor, class ScratchAccessor>
116
119
static void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
117
120
{
121
+ using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
122
+ using params_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
123
+
118
124
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex ();
119
- subgroup2::inclusive_scan<Params > inclusiveScan0;
125
+ subgroup2::inclusive_scan<params_t > inclusiveScan0;
120
126
// level 0 scan
121
127
[unroll]
122
128
for (uint16_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
@@ -147,11 +153,10 @@ struct reduce<Config, BinOp, 2, device_capabilities>
147
153
scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
148
154
{
149
155
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
150
- using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
151
156
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
152
157
BinOp binop;
153
158
154
- reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t >(dataAccessor, scratchAccessor);
159
+ reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
155
160
156
161
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex ();
157
162
// level 1 scan
@@ -186,11 +191,10 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
186
191
void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
187
192
{
188
193
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
189
- using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
190
194
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
191
195
BinOp binop;
192
196
193
- scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t >(dataAccessor, scratchAccessor);
197
+ scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
194
198
195
199
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex ();
196
200
// level 1 scan
@@ -216,11 +220,9 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
216
220
dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
217
221
218
222
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1 >(uint16_t (glsl::gl_SubgroupID ()-1u), idx);
219
- scalar_t left;
223
+ scalar_t left = BinOp::identity ;
220
224
if (idx != 0 || glsl::gl_SubgroupID () != 0 )
221
225
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
222
- else
223
- left = BinOp::identity;
224
226
if (Exclusive)
225
227
{
226
228
scalar_t left_last_elem = hlsl::mix (BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1 ],1 ), bool (glsl::gl_SubgroupInvocationID ()));
@@ -253,12 +255,11 @@ struct reduce<Config, BinOp, 3, device_capabilities>
253
255
scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
254
256
{
255
257
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
256
- using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
257
258
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
258
259
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
259
260
BinOp binop;
260
261
261
- reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t >(dataAccessor, scratchAccessor);
262
+ reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
262
263
263
264
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex ();
264
265
// level 1 scan
@@ -310,12 +311,11 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
310
311
void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
311
312
{
312
313
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
313
- using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
314
314
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
315
315
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
316
316
BinOp binop;
317
317
318
- scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t >(dataAccessor, scratchAccessor);
318
+ scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
319
319
320
320
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex ();
321
321
// level 1 scan
@@ -357,12 +357,10 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
357
357
for (uint16_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
358
358
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i), lv1_val[i]);
359
359
360
- scalar_t lv2_scan;
360
+ scalar_t lv2_scan = BinOp::identity ;
361
361
const uint16_t bankedIndex = Config::template sharedStoreIndex<2 >(uint16_t (glsl::gl_SubgroupID ()-1u));
362
362
if (glsl::gl_SubgroupID () != 0 )
363
363
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, lv2_scan);
364
- else
365
- lv2_scan = BinOp::identity;
366
364
367
365
[unroll]
368
366
for (uint16_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
@@ -378,11 +376,9 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
378
376
dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
379
377
380
378
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1 >(uint16_t (glsl::gl_SubgroupID ()-1u), idx);
381
- scalar_t left;
379
+ scalar_t left = BinOp::identity ;
382
380
if (idx != 0 || glsl::gl_SubgroupID () != 0 )
383
381
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
384
- else
385
- left = BinOp::identity;
386
382
if (Exclusive)
387
383
{
388
384
scalar_t left_last_elem = hlsl::mix (BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1 ],1 ), bool (glsl::gl_SubgroupInvocationID ()));
0 commit comments