@@ -138,44 +138,54 @@ SPECIALIZE_ALL(maximum,Max);
 #undef SPECIALIZE_ALL
 #undef SPECIALIZE
 
+template<class BinOp, uint16_t begin, uint16_t end>
+struct inclusive_scan_impl
+{
+    using scalar_t = typename BinOp::type_t;
+
+    static scalar_t __call(scalar_t value)
+    {
+        BinOp op;
+        const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
+        const uint32_t step = 1u << begin;
+        scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
+        scalar_t new_value = op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
+        return inclusive_scan_impl<BinOp,begin+1,end>::__call(new_value);
+    }
+};
+
+template<class BinOp, uint16_t end>
+struct inclusive_scan_impl<BinOp,end,end>
+{
+    using scalar_t = typename BinOp::type_t;
+
+    static scalar_t __call(scalar_t value)
+    {
+        BinOp op;
+        const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
+        const uint32_t step = 1u << end;
+        scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
+        return op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
+    }
+};
+
 // specialize portability
 template<class Params, class BinOp>
 struct inclusive_scan<Params, BinOp, 1, false>
 {
     using type_t = typename Params::type_t;
     using scalar_t = typename Params::scalar_t;
     using binop_t = typename Params::binop_t;
-    // assert T == scalar type, binop::type == T
     using config_t = typename Params::config_t;
 
-    // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
-    // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
-
     scalar_t operator()(scalar_t value)
     {
         return __call(value);
     }
 
     static scalar_t __call(scalar_t value)
     {
-        // sync up each subgroup invocation so it runs in lockstep
-        spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
-
-        binop_t op;
-        const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
-
-        scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
-        value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
-
-        const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
-        [unroll]
-        for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
-        {
-            const uint32_t step = 1u << i;
-            rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
-            value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
-        }
-        return value;
+        return inclusive_scan_impl<binop_t, 0, config_t::SizeLog2-1>::__call(value);
     }
 };
 
@@ -192,10 +202,36 @@ struct exclusive_scan<Params, BinOp, 1, false>
         spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
 
         scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
+        spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
         return inclusive_scan<Params, BinOp, 1, false>::__call(left);
     }
 };
 
+template<class BinOp, uint16_t begin, uint16_t end>
+struct reduction_impl
+{
+    using scalar_t = typename BinOp::type_t;
+
+    static scalar_t __call(scalar_t value)
+    {
+        BinOp op;
+        scalar_t new_value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<begin),value);
+        return reduction_impl<BinOp,begin+1,end>::__call(new_value);
+    }
+};
+
+template<class BinOp, uint16_t end>
+struct reduction_impl<BinOp,end,end>
+{
+    using scalar_t = typename BinOp::type_t;
+
+    static scalar_t __call(scalar_t value)
+    {
+        BinOp op;
+        return op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<end),value);
+    }
+};
+
 
 template<class Params, class BinOp>
 struct reduction<Params, BinOp, 1, false>
@@ -204,21 +240,9 @@ struct reduction<Params, BinOp, 1, false>
     using binop_t = typename Params::binop_t;
     using config_t = typename Params::config_t;
 
-    // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
-    // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
-
     scalar_t operator()(scalar_t value)
     {
-        // sync up each subgroup invocation so it runs in lockstep
-        spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
-
-        binop_t op;
-        const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
-        [unroll]
-        for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
-            value = op(glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);
-
-        return value;
+        return reduction_impl<binop_t, 0, config_t::SizeLog2-1>::__call(value);
    }
 };
 
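For reference, the diff replaces the runtime `[unroll]` loops (which read `config_t::SizeLog2` through a workaround for the linked DXC issue) with compile-time recursion over a `begin`/`end` template parameter pair, terminated by a partial specialization where `begin == end`. Below is a minimal standalone C++ sketch of the same unrolling pattern, not the library's actual API: the names (`scan_step`, `plus_u32`) are hypothetical, and a plain array stands in for the subgroup shuffles.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical binary operator in the style the diff assumes:
// exposes type_t, identity and operator().
struct plus_u32
{
    using type_t = uint32_t;
    static constexpr type_t identity = 0u;
    type_t operator()(type_t a, type_t b) const { return a + b; }
};

// Primary template: performs the scan step for offset (1 << begin),
// then recurses at compile time with begin + 1.
template<class BinOp, uint16_t begin, uint16_t end, size_t N>
struct scan_step
{
    static void run(typename BinOp::type_t (&data)[N])
    {
        BinOp op;
        constexpr size_t step = size_t(1) << begin;
        typename BinOp::type_t next[N];
        for (size_t i = 0; i < N; ++i) // combine with the element `step` slots to the left, or identity
            next[i] = op(data[i], i >= step ? data[i - step] : BinOp::identity);
        for (size_t i = 0; i < N; ++i)
            data[i] = next[i];
        scan_step<BinOp, begin + 1, end, N>::run(data);
    }
};

// Partial specialization: begin == end performs the last step and ends the recursion.
template<class BinOp, uint16_t end, size_t N>
struct scan_step<BinOp, end, end, N>
{
    static void run(typename BinOp::type_t (&data)[N])
    {
        BinOp op;
        constexpr size_t step = size_t(1) << end;
        typename BinOp::type_t next[N];
        for (size_t i = 0; i < N; ++i)
            next[i] = op(data[i], i >= step ? data[i - step] : BinOp::identity);
        for (size_t i = 0; i < N; ++i)
            data[i] = next[i];
    }
};

int main()
{
    uint32_t data[8] = {1, 1, 1, 1, 1, 1, 1, 1};
    // log2(8) = 3 steps, so begin = 0 .. end = 2, matching the SizeLog2-1 bound in the diff.
    scan_step<plus_u32, 0, 2, 8>::run(data);
    for (uint32_t v : data)
        std::printf("%u ", v); // expected inclusive prefix sums: 1 2 3 4 5 6 7 8
    std::printf("\n");
    return 0;
}
```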