10
10
#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11
11
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
12
12
13
+ #include "nbl/builtin/hlsl/algorithm.hlsl"
13
14
#include "nbl/builtin/hlsl/functional.hlsl"
14
15
#include "nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl"
15
16
@@ -138,35 +139,31 @@ SPECIALIZE_ALL(maximum,Max);
138
139
#undef SPECIALIZE_ALL
139
140
#undef SPECIALIZE
140
141
141
// Diff of inclusive_scan_impl.
// Removed ('-') form: a recursive template over <BinOp,begin,end> whose static __call performed
// one shuffle-up step with stride (1u << begin) and recursed with begin+1; the <BinOp,end,end>
// specialization performed the final step with stride (1u << end).
// Added ('+') form: a single template over <BinOp> that keeps the running `value` and the
// invocation index as members, is constructed through create(), and performs exactly one
// step per __call<StepLog2>() invocation, updating `value` in place.
- template<class BinOp, uint16_t begin, uint16_t end >
142
+ template<class BinOp>
142
143
struct inclusive_scan_impl
143
144
{
144
145
using scalar_t = typename BinOp::type_t;
145
146
146
// create(): stores the initial `_value` and reads glsl::gl_SubgroupInvocationID() once,
// whereas the removed code read the invocation ID again in every recursion level.
- static scalar_t __call (scalar_t value )
147
+ static inclusive_scan_impl<BinOp> create (scalar_t _value )
147
148
{
148
- BinOp op;
149
- const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID ();
150
- const uint32_t step = 1u << begin;
151
- scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
152
- scalar_t new_value = op (value, hlsl::mix (rhs, BinOp::identity, subgroupInvocation < step));
153
- return inclusive_scan_impl<BinOp,begin+1 ,end>::__call (new_value);
149
+ inclusive_scan_impl<BinOp> retval;
150
+ retval.value = _value;
151
+ retval.subgroupInvocation = glsl::gl_SubgroupInvocationID ();
152
+ return retval;
154
153
}
155
- };
156
154
157
- template<class BinOp, uint16_t end>
158
- struct inclusive_scan_impl<BinOp,end,end>
159
- {
160
- using scalar_t = typename BinOp::type_t;
161
-
162
// __call<StepLog2>(): one scan step with stride step = (1u << StepLog2).
// It takes no argument and returns nothing; the result is accumulated in the `value` member.
- static scalar_t __call (scalar_t value)
155
+ template<uint16_t StepLog2>
156
+ void __call ()
163
157
{
164
158
BinOp op;
165
- const uint32_t subgroupInvocation = glsl:: gl_SubgroupInvocationID () ;
166
- const uint32_t step = 1u << end ;
159
+ const uint32_t step = 1u << StepLog2 ;
// Added: a subgroup-scope control barrier is issued before every shuffle of this step.
160
+ spirv:: controlBarrier (spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone) ;
167
161
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
// NOTE(review): presumably hlsl::mix selects BinOp::identity when (subgroupInvocation < step),
// i.e. for invocations that have no source lane `step` positions below them - confirm against hlsl::mix.
168
- return op (value, hlsl::mix (rhs, BinOp::identity, subgroupInvocation < step));
162
+ value = op (value, hlsl::mix (rhs, BinOp::identity, subgroupInvocation < step));
169
163
}
164
+
// State carried between steps: the running scan value and the invocation index captured in create().
165
+ scalar_t value;
166
+ uint32_t subgroupInvocation;
170
167
};
171
168
172
169
// specialize portability
@@ -185,7 +182,9 @@ struct inclusive_scan<Params, BinOp, 1, false>
185
182
186
183
// Inclusive scan entry point of the inclusive_scan<Params, BinOp, 1, false> specialization.
// Removed ('-') line: delegated to the recursive inclusive_scan_impl<binop_t, 0, SizeLog2-1>.
// Added ('+') lines: build the functor from `value`, hand it to
// unrolled_for_range<0, config_t::SizeLog2>, then return the accumulated f_impl.value.
// NOTE(review): unrolled_for_range is presumably provided by the newly included
// "nbl/builtin/hlsl/algorithm.hlsl" and presumably invokes f_impl.__call<i>() for each i in
// [0, config_t::SizeLog2) - confirm; that would match the removed recursion, which applied
// steps for exponents 0 through SizeLog2-1 inclusive.
static scalar_t __call (scalar_t value)
187
184
{
188
- return inclusive_scan_impl<binop_t, 0 , config_t::SizeLog2-1 >::__call (value);
185
+ inclusive_scan_impl<binop_t> f_impl = inclusive_scan_impl<binop_t>::create (value);
186
+ unrolled_for_range<0 , config_t::SizeLog2>::template __call<inclusive_scan_impl<binop_t> >(f_impl);
187
+ return f_impl.value;
189
188
}
190
189
};
191
190
@@ -202,34 +201,30 @@ struct exclusive_scan<Params, BinOp, 1, false>
202
201
spirv::controlBarrier (spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
203
202
204
203
scalar_t left = hlsl::mix (binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1 ), bool (glsl::gl_SubgroupInvocationID ()));
205
- spirv::controlBarrier (spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
206
204
return inclusive_scan<Params, BinOp, 1 , false >::__call (left);
207
205
}
208
206
};
209
207
210
// Diff of reduction_impl.
// Removed ('-') form: a recursive template over <BinOp,begin,end> whose static __call combined
// `value` with subgroupShuffleXor(value, 0x1u<<begin) and recursed with begin+1; the
// <BinOp,end,end> specialization applied the last step with mask (0x1u<<end).
// Added ('+') form: a single template over <BinOp> holding the running `value` as a member,
// constructed through create(), with one XOR-shuffle step per __call<StepLog2>() invocation.
- template<class BinOp, uint16_t begin, uint16_t end >
208
+ template<class BinOp>
211
209
struct reduction_impl
212
210
{
213
211
using scalar_t = typename BinOp::type_t;
214
212
215
// create(): wraps the initial `_value` in a functor instance; no other state is initialised.
- static scalar_t __call (scalar_t value )
213
+ static reduction_impl<BinOp> create (scalar_t _value )
216
214
{
217
- BinOp op ;
218
- scalar_t new_value = op (glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<begin),value) ;
219
- return reduction_impl<BinOp,begin+ 1 ,end>:: __call (new_value) ;
215
+ reduction_impl< BinOp> retval ;
216
+ retval.value = _value ;
217
+ return retval ;
220
218
}
221
- };
222
219
223
- template<class BinOp, uint16_t end>
224
- struct reduction_impl<BinOp,end,end>
225
- {
226
- using scalar_t = typename BinOp::type_t;
227
-
228
// __call<StepLog2>(): one reduction step; the operand order (shuffled value first, own value
// second) is the same as in the removed code. The result is written back to the `value` member.
- static scalar_t __call (scalar_t value)
220
+ template<uint16_t StepLog2>
221
+ void __call ()
229
222
{
230
223
BinOp op;
231
- return op (glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<end ),value);
224
+ value = op (glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<StepLog2 ),value);
232
225
}
226
+
// Running reduction value, updated in place by each __call<StepLog2>().
227
+ scalar_t value;
233
228
};
234
229
235
230
template<class Params, class BinOp>
@@ -242,7 +237,9 @@ struct reduction<Params, BinOp, 1, false>
242
237
243
238
// Reduction entry point of the reduction<Params, BinOp, 1, false> specialization.
// Removed ('-') line: delegated to the recursive reduction_impl<binop_t, 0, SizeLog2-1>.
// Added ('+') lines: build the functor from `value`, hand it to
// unrolled_for_range<0, config_t::SizeLog2>, then return the accumulated f_impl.value.
// NOTE(review): unrolled_for_range presumably invokes f_impl.__call<i>() for each i in
// [0, config_t::SizeLog2) - confirm against its definition; that would match the removed
// recursion, which applied steps for exponents 0 through SizeLog2-1 inclusive.
scalar_t operator ()(scalar_t value)
244
239
{
245
- return reduction_impl<binop_t, 0 , config_t::SizeLog2-1 >::__call (value);
240
+ reduction_impl<binop_t> f_impl = reduction_impl<binop_t>::create (value);
241
+ unrolled_for_range<0 , config_t::SizeLog2>::template __call<reduction_impl<binop_t> >(f_impl);
242
+ return f_impl.value;
246
243
}
247
244
};
248
245
0 commit comments