@@ -158,6 +158,10 @@ struct inclusive_scan<Params, BinOp, 1, false>
 
     static scalar_t __call(scalar_t value)
     {
+        // sync up each subgroup invocation so it runs in lockstep
+        // not ideal because might not write to shared memory but a storage class is needed
+        spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
+
         binop_t op;
         const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
@@ -185,6 +189,10 @@ struct exclusive_scan<Params, BinOp, 1, false>
 
     scalar_t operator()(scalar_t value)
    {
+        // sync up each subgroup invocation so it runs in lockstep
+        // not ideal because might not write to shared memory but a storage class is needed
+        spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
+
         scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value, 1), bool(glsl::gl_SubgroupInvocationID()));
         return inclusive_scan<Params, BinOp, 1, false>::__call(left);
     }
@@ -203,8 +211,11 @@ struct reduction<Params, BinOp, 1, false>
 
     scalar_t operator()(scalar_t value)
    {
-        binop_t op;
+        // sync up each subgroup invocation so it runs in lockstep
+        // not ideal because might not write to shared memory but a storage class is needed
+        spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
 
+        binop_t op;
         const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
         [unroll]
         for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
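
The same three-line barrier preamble is inserted at the top of all three specializations. As a minimal sketch only (the helper name below is hypothetical and not part of this diff), the pattern could be factored into one function, assuming the surrounding headers already expose the spirv::memoryBarrier intrinsic and spv constants used above:

// hypothetical helper, shown only to illustrate the inserted pattern
void subgroupEmulationBarrier()
{
    // OpMemoryBarrier at Subgroup scope; the WorkgroupMemory storage-class bit is set alongside
    // Acquire even though the emulated ops may never touch shared memory, because the memory
    // semantics need some storage class mask to apply to
    spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
}

Each __call()/operator() body would then begin with subgroupEmulationBarrier(); before the shuffle-based scan or reduction loop.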