Skip to content

Commit 9a797a0

Browse files
Merge pull request #893 from Devsh-Graphics-Programming/minor-subgroup-sync-adjustment
change to subgroup control barrier
2 parents 5dcd5c3 + 9c7239b commit 9a797a0

File tree

2 files changed

+78
-42
lines changed

2 files changed

+78
-42
lines changed

include/nbl/builtin/hlsl/algorithm.hlsl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ NBL_CONSTEXPR_INLINE_FUNC void swap(NBL_REF_ARG(T) lhs, NBL_REF_ARG(T) rhs)
8888
}
8989

9090

91-
#ifdef __HLSL_VERSION
92-
9391
namespace impl
9492
{
9593

@@ -224,7 +222,26 @@ uint upper_bound(NBL_REF_ARG(Accessor) accessor, const uint begin, const uint en
224222
return impl::upper_bound<Accessor,typename Accessor::value_type>(accessor,begin,end,value);
225223
}
226224

227-
#endif
225+
226+
// Compile-time unrolled loop: invokes `f.template __call<i>()` once for each
// i = begin, begin+1, ..., end-1, in that order, via recursive template instantiation.
// F must provide a parameterless member template `__call<N>()`.
// NOTE(review): there is no visible guard against begin > end; the recursion only
// terminates when `begin` reaches `end` exactly - confirm callers always pass begin <= end.
template<int begin, int end>
227+
struct unrolled_for_range;
228+
// Terminal case (begin == end): nothing left to invoke, so the body is empty.
template<int end>
229+
struct unrolled_for_range<end,end>
230+
{
231+
template<typename F>
232+
static void __call(NBL_REF_ARG(F) f) {}
233+
};
234+
// Recursive case: run the step for `begin`, then continue with [begin+1, end).
template<int begin, int end>
235+
struct unrolled_for_range
236+
{
237+
template<typename F>
238+
// `f` is taken through NBL_REF_ARG - presumably by reference, so state mutated by
// `f.__call` is visible to the caller afterwards (TODO confirm the macro's expansion).
static void __call(NBL_REF_ARG(F) f)
239+
{
240+
// the loop index is handed to the functor as a template argument
f.template __call<begin>();
241+
unrolled_for_range<begin+1,end>::template __call<F>(f);
242+
}
243+
};
244+
228245
}
229246
}
230247

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
1111
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1212

13+
#include "nbl/builtin/hlsl/algorithm.hlsl"
1314
#include "nbl/builtin/hlsl/functional.hlsl"
1415
#include "nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl"
1516

@@ -138,45 +139,52 @@ SPECIALIZE_ALL(maximum,Max);
138139
#undef SPECIALIZE_ALL
139140
#undef SPECIALIZE
140141

142+
// Per-invocation state for a shuffle-based inclusive scan.
// Holds the running `value` and the invocation's subgroup index; each
// `__call<StepLog2>()` performs one scan step with stride `1 << StepLog2`.
// Has the `__call<N>()` member-template shape that `unrolled_for_range` invokes.
template<class BinOp>
143+
struct inclusive_scan_impl
144+
{
145+
using scalar_t = typename BinOp::type_t;
146+
147+
// Builds the functor from the invocation's input value and caches
// gl_SubgroupInvocationID() so it is read once rather than per step.
static inclusive_scan_impl<BinOp> create(scalar_t _value)
148+
{
149+
inclusive_scan_impl<BinOp> retval;
150+
retval.value = _value;
151+
retval.subgroupInvocation = glsl::gl_SubgroupInvocationID();
152+
return retval;
153+
}
154+
155+
// One scan step: shuffle `value` up by `step` and fold the result into `value`.
template<uint16_t StepLog2>
156+
void __call()
157+
{
158+
BinOp op;
159+
const uint32_t step = 1u << StepLog2;
160+
// subgroup-scoped control barrier with no memory semantics, issued before every shuffle
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
161+
// the shuffle is executed unconditionally; the selection happens afterwards in `mix`
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
162+
// presumably `mix` yields BinOp::identity when `subgroupInvocation < step`, i.e. for
// invocations with no source `step` lanes below them - TODO confirm hlsl::mix semantics
value = op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
163+
}
164+
165+
// running scan result for this invocation
scalar_t value;
166+
// gl_SubgroupInvocationID() captured in create()
uint32_t subgroupInvocation;
167+
};
168+
141169
// specialize portability
142170
template<class Params, class BinOp>
143171
struct inclusive_scan<Params, BinOp, 1, false>
144172
{
145173
using type_t = typename Params::type_t;
146174
using scalar_t = typename Params::scalar_t;
147175
using binop_t = typename Params::binop_t;
148-
// assert T == scalar type, binop::type == T
149176
using config_t = typename Params::config_t;
150177

151-
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
152-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
153-
154178
scalar_t operator()(scalar_t value)
155179
{
156180
return __call(value);
157181
}
158182

159183
static scalar_t __call(scalar_t value)
160184
{
161-
// sync up each subgroup invocation so it runs in lockstep
162-
// not ideal because might not write to shared memory but a storage class is needed
163-
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
164-
165-
binop_t op;
166-
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
167-
168-
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
169-
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
170-
171-
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
172-
[unroll]
173-
for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
174-
{
175-
const uint32_t step = 1u << i;
176-
rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
177-
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
178-
}
179-
return value;
185+
inclusive_scan_impl<binop_t> f_impl = inclusive_scan_impl<binop_t>::create(value);
186+
unrolled_for_range<0, config_t::SizeLog2>::template __call<inclusive_scan_impl<binop_t> >(f_impl);
187+
return f_impl.value;
180188
}
181189
};
182190

@@ -190,14 +198,36 @@ struct exclusive_scan<Params, BinOp, 1, false>
190198
scalar_t operator()(scalar_t value)
191199
{
192200
// sync up each subgroup invocation so it runs in lockstep
193-
// not ideal because might not write to shared memory but a storage class is needed
194-
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
201+
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
195202

196203
scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
197204
return inclusive_scan<Params, BinOp, 1, false>::__call(left);
198205
}
199206
};
200207

208+
// Per-invocation state for a shuffle-xor based reduction.
// Holds the running `value`; each `__call<StepLog2>()` combines it with the value
// obtained through subgroupShuffleXor using mask `1 << StepLog2`.
// Has the `__call<N>()` member-template shape that `unrolled_for_range` invokes.
template<class BinOp>
209+
struct reduction_impl
210+
{
211+
using scalar_t = typename BinOp::type_t;
212+
213+
// Builds the functor from the invocation's input value.
static reduction_impl<BinOp> create(scalar_t _value)
214+
{
215+
reduction_impl<BinOp> retval;
216+
retval.value = _value;
217+
return retval;
218+
}
219+
220+
// One reduction step for bit `StepLog2` of the invocation index.
template<uint16_t StepLog2>
221+
void __call()
222+
{
223+
BinOp op;
224+
// subgroup-scoped control barrier with no memory semantics, issued before every shuffle
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
225+
// the shuffled-in value is the left operand, this invocation's own value the right
value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<StepLog2),value);
226+
}
227+
228+
// running reduction result for this invocation
scalar_t value;
229+
};
230+
201231
template<class Params, class BinOp>
202232
struct reduction<Params, BinOp, 1, false>
203233
{
@@ -206,22 +236,11 @@ struct reduction<Params, BinOp, 1, false>
206236
using binop_t = typename Params::binop_t;
207237
using config_t = typename Params::config_t;
208238

209-
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
210-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
211-
212239
scalar_t operator()(scalar_t value)
213240
{
214-
// sync up each subgroup invocation so it runs in lockstep
215-
// not ideal because might not write to shared memory but a storage class is needed
216-
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
217-
218-
binop_t op;
219-
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
220-
[unroll]
221-
for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
222-
value = op(glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);
223-
224-
return value;
241+
reduction_impl<binop_t> f_impl = reduction_impl<binop_t>::create(value);
242+
unrolled_for_range<0, config_t::SizeLog2>::template __call<reduction_impl<binop_t> >(f_impl);
243+
return f_impl.value;
225244
}
226245
};
227246

0 commit comments

Comments
 (0)