Skip to content

Commit b627f96

Browse files
committed
use unrolled for, added missing barrier
1 parent 228e297 commit b627f96

File tree

2 files changed

+32
-35
lines changed

2 files changed

+32
-35
lines changed

include/nbl/builtin/hlsl/algorithm.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ struct unrolled_for_range
240240
static void __call(inout F f)
241241
{
242242
f.template __call<begin>();
243-
unrolled_for<begin+1,end>::template __call<F>(f);
243+
unrolled_for_range<begin+1,end>::template __call<F>(f);
244244
}
245245
};
246246

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
1111
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1212

13+
#include "nbl/builtin/hlsl/algorithm.hlsl"
1314
#include "nbl/builtin/hlsl/functional.hlsl"
1415
#include "nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl"
1516

@@ -138,35 +139,31 @@ SPECIALIZE_ALL(maximum,Max);
138139
#undef SPECIALIZE_ALL
139140
#undef SPECIALIZE
140141

141-
template<class BinOp, uint16_t begin, uint16_t end>
142+
template<class BinOp>
142143
struct inclusive_scan_impl
143144
{
144145
using scalar_t = typename BinOp::type_t;
145146

146-
static scalar_t __call(scalar_t value)
147+
static inclusive_scan_impl<BinOp> create(scalar_t _value)
147148
{
148-
BinOp op;
149-
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
150-
const uint32_t step = 1u << begin;
151-
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
152-
scalar_t new_value = op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
153-
return inclusive_scan_impl<BinOp,begin+1,end>::__call(new_value);
149+
inclusive_scan_impl<BinOp> retval;
150+
retval.value = _value;
151+
retval.subgroupInvocation = glsl::gl_SubgroupInvocationID();
152+
return retval;
154153
}
155-
};
156154

157-
template<class BinOp, uint16_t end>
158-
struct inclusive_scan_impl<BinOp,end,end>
159-
{
160-
using scalar_t = typename BinOp::type_t;
161-
162-
static scalar_t __call(scalar_t value)
155+
template<uint16_t StepLog2>
156+
void __call()
163157
{
164158
BinOp op;
165-
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
166-
const uint32_t step = 1u << end;
159+
const uint32_t step = 1u << StepLog2;
160+
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
167161
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
168-
return op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
162+
value = op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
169163
}
164+
165+
scalar_t value;
166+
uint32_t subgroupInvocation;
170167
};
171168

172169
// specialize portability
@@ -185,7 +182,9 @@ struct inclusive_scan<Params, BinOp, 1, false>
185182

186183
static scalar_t __call(scalar_t value)
187184
{
188-
return inclusive_scan_impl<binop_t, 0, config_t::SizeLog2-1>::__call(value);
185+
inclusive_scan_impl<binop_t> f_impl = inclusive_scan_impl<binop_t>::create(value);
186+
unrolled_for_range<0, config_t::SizeLog2>::template __call<inclusive_scan_impl<binop_t> >(f_impl);
187+
return f_impl.value;
189188
}
190189
};
191190

@@ -202,34 +201,30 @@ struct exclusive_scan<Params, BinOp, 1, false>
202201
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
203202

204203
scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
205-
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
206204
return inclusive_scan<Params, BinOp, 1, false>::__call(left);
207205
}
208206
};
209207

210-
template<class BinOp, uint16_t begin, uint16_t end>
208+
template<class BinOp>
211209
struct reduction_impl
212210
{
213211
using scalar_t = typename BinOp::type_t;
214212

215-
static scalar_t __call(scalar_t value)
213+
static reduction_impl<BinOp> create(scalar_t _value)
216214
{
217-
BinOp op;
218-
scalar_t new_value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<begin),value);
219-
return reduction_impl<BinOp,begin+1,end>::__call(new_value);
215+
reduction_impl<BinOp> retval;
216+
retval.value = _value;
217+
return retval;
220218
}
221-
};
222219

223-
template<class BinOp, uint16_t end>
224-
struct reduction_impl<BinOp,end,end>
225-
{
226-
using scalar_t = typename BinOp::type_t;
227-
228-
static scalar_t __call(scalar_t value)
220+
template<uint16_t StepLog2>
221+
void __call()
229222
{
230223
BinOp op;
231-
return op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<end),value);
224+
value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<StepLog2),value);
232225
}
226+
227+
scalar_t value;
233228
};
234229

235230
template<class Params, class BinOp>
@@ -242,7 +237,9 @@ struct reduction<Params, BinOp, 1, false>
242237

243238
scalar_t operator()(scalar_t value)
244239
{
245-
return reduction_impl<binop_t, 0, config_t::SizeLog2-1>::__call(value);
240+
reduction_impl<binop_t> f_impl = reduction_impl<binop_t>::create(value);
241+
unrolled_for_range<0, config_t::SizeLog2>::template __call<reduction_impl<binop_t> >(f_impl);
242+
return f_impl.value;
246243
}
247244
};
248245

0 commit comments

Comments
 (0)