Skip to content

Commit 3a7ad67

Browse files
committed
added memory barrier to subgroup scans
1 parent 9b340a4 commit 3a7ad67

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ struct inclusive_scan<Params, BinOp, 1, false>
158158

159159
static scalar_t __call(scalar_t value)
160160
{
161+
// sync up all subgroup invocations so that they run in lockstep
162+
// not ideal: this code might not write to shared memory, but the barrier needs a storage class to be specified
163+
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
164+
161165
binop_t op;
162166
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
163167

@@ -185,6 +189,10 @@ struct exclusive_scan<Params, BinOp, 1, false>
185189

186190
scalar_t operator()(scalar_t value)
187191
{
192+
// sync up all subgroup invocations so that they run in lockstep
193+
// not ideal: this code might not write to shared memory, but the barrier needs a storage class to be specified
194+
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
195+
188196
scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
189197
return inclusive_scan<Params, BinOp, 1, false>::__call(left);
190198
}
@@ -203,8 +211,11 @@ struct reduction<Params, BinOp, 1, false>
203211

204212
scalar_t operator()(scalar_t value)
205213
{
206-
binop_t op;
214+
// sync up all subgroup invocations so that they run in lockstep
215+
// not ideal: this code might not write to shared memory, but the barrier needs a storage class to be specified
216+
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
207217

218+
binop_t op;
208219
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
209220
[unroll]
210221
for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)

0 commit comments

Comments
 (0)