Skip to content

Commit 4b7ac8f

Browse files
committed
use template recursions instead
1 parent dd5d963 commit 4b7ac8f

File tree

1 file changed

+59
-35
lines changed

1 file changed

+59
-35
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -138,44 +138,54 @@ SPECIALIZE_ALL(maximum,Max);
138138
#undef SPECIALIZE_ALL
139139
#undef SPECIALIZE
140140

141+
// Recursive case of the shuffle-based inclusive scan, unrolled through template instantiation.
// The instantiation for `begin` processes the stride `1u << begin` and then hands its result to the
// `begin+1` instantiation; the <BinOp,end,end> partial specialization ends the chain, so `end` is the
// last exponent that gets processed (inclusive bound).
// BinOp must expose `type_t`, a static `identity` member and a two-argument call operator.
template<class BinOp, uint16_t begin, uint16_t end>
142+
struct inclusive_scan_impl
143+
{
144+
// element type the scan operates on, taken from the binary operator
using scalar_t = typename BinOp::type_t;
145+
146+
// `value` is this invocation's running partial result; the return value is the result
// after the strides for exponents `begin` through `end` have been applied
static scalar_t __call(scalar_t value)
147+
{
148+
BinOp op;
149+
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
150+
// stride handled by this recursion level
const uint32_t step = 1u << begin;
151+
// the shuffle is issued unconditionally; its result is filtered by the mix() below
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
152+
// NOTE(review): presumably hlsl::mix selects BinOp::identity when `subgroupInvocation < step`
// (GLSL-style boolean mix), so low invocations leave `value` unchanged - confirm against hlsl::mix
scalar_t new_value = op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
153+
// continue with the next, doubled stride
return inclusive_scan_impl<BinOp,begin+1,end>::__call(new_value);
154+
}
155+
};
156+
157+
// Terminating case of inclusive_scan_impl, chosen when `begin` has reached `end`.
// It performs the same shuffle-and-combine step as the recursive case, using the stride `1u << end`,
// and returns the combined value directly instead of instantiating a further level.
template<class BinOp, uint16_t end>
158+
struct inclusive_scan_impl<BinOp,end,end>
159+
{
160+
// element type the scan operates on, taken from the binary operator
using scalar_t = typename BinOp::type_t;
161+
162+
// `value` is the partial result produced by the previous levels; returns the final scan value
static scalar_t __call(scalar_t value)
163+
{
164+
BinOp op;
165+
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
166+
// last stride of the chain
const uint32_t step = 1u << end;
167+
// the shuffle is issued unconditionally; its result is filtered by the mix() below
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
168+
// NOTE(review): presumably hlsl::mix selects BinOp::identity when `subgroupInvocation < step` - confirm
return op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
169+
}
170+
};
171+
141172
// specialize portability
142173
// Partial specialization of inclusive_scan for the template arguments <Params, BinOp, 1, false>.
// The meaning of the last two arguments is defined by the primary template, which is not part of this hunk.
// The element type, binary operator and subgroup configuration are all read from Params.
template<class Params, class BinOp>
143174
struct inclusive_scan<Params, BinOp, 1, false>
144175
{
145176
using type_t = typename Params::type_t;
146177
using scalar_t = typename Params::scalar_t;
147178
using binop_t = typename Params::binop_t;
148-
// assert T == scalar type, binop::type == T
149179
using config_t = typename Params::config_t;
150180

151-
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
152-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
153-
154181
// functor entry point; forwards to the static __call
scalar_t operator()(scalar_t value)
155182
{
156183
return __call(value);
157184
}
158185

159186
// static entry point, also invoked directly by exclusive_scan<Params, BinOp, 1, false>
static scalar_t __call(scalar_t value)
160187
{
161-
// sync up each subgroup invocation so it runs in lockstep
162-
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
163-
164-
binop_t op;
165-
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
166-
167-
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
168-
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
169-
170-
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
171-
[unroll]
172-
for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
173-
{
174-
const uint32_t step = 1u << i;
175-
rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
176-
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
177-
}
178-
return value;
188+
// delegates to the template recursion, covering stride exponents 0 through config_t::SizeLog2-1 inclusive
// NOTE(review): the upper bound is a uint16_t template parameter, so this presumably relies on
// config_t::SizeLog2 >= 1 (otherwise SizeLog2-1 would not be a valid last exponent) - confirm
return inclusive_scan_impl<binop_t, 0, config_t::SizeLog2-1>::__call(value);
179189
}
180190
};
181191

@@ -192,10 +202,36 @@ struct exclusive_scan<Params, BinOp, 1, false>
192202
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
193203

194204
scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
205+
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
195206
return inclusive_scan<Params, BinOp, 1, false>::__call(left);
196207
}
197208
};
198209

210+
// Recursive case of the shuffle-based reduction, unrolled through template instantiation.
// The instantiation for `begin` combines the invocation's value with the one obtained through
// glsl::subgroupShuffleXor using the mask `0x1u<<begin`, then hands the result to the `begin+1`
// instantiation; the <BinOp,end,end> partial specialization ends the chain, so `end` is inclusive.
// BinOp must expose `type_t` and a two-argument call operator.
template<class BinOp, uint16_t begin, uint16_t end>
211+
struct reduction_impl
212+
{
213+
// element type the reduction operates on, taken from the binary operator
using scalar_t = typename BinOp::type_t;
214+
215+
// `value` is this invocation's running partial result; the return value is the result
// after the masks for exponents `begin` through `end` have been applied
static scalar_t __call(scalar_t value)
216+
{
217+
BinOp op;
218+
// the shuffled value is passed as the first operand and the invocation's own value as the second
scalar_t new_value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<begin),value);
219+
// continue with the next, doubled mask
return reduction_impl<BinOp,begin+1,end>::__call(new_value);
220+
}
221+
};
222+
223+
// Terminating case of reduction_impl, chosen when `begin` has reached `end`.
// It performs the same XOR-shuffle-and-combine step as the recursive case, using the mask `0x1u<<end`,
// and returns the combined value directly instead of instantiating a further level.
template<class BinOp, uint16_t end>
224+
struct reduction_impl<BinOp,end,end>
225+
{
226+
// element type the reduction operates on, taken from the binary operator
using scalar_t = typename BinOp::type_t;
227+
228+
// `value` is the partial result produced by the previous levels; returns the final reduced value
static scalar_t __call(scalar_t value)
229+
{
230+
BinOp op;
231+
// the shuffled value is passed as the first operand and the invocation's own value as the second
return op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<end),value);
232+
}
233+
};
234+
199235
template<class Params, class BinOp>
200236
struct reduction<Params, BinOp, 1, false>
201237
{
@@ -204,21 +240,9 @@ struct reduction<Params, BinOp, 1, false>
204240
using binop_t = typename Params::binop_t;
205241
using config_t = typename Params::config_t;
206242

207-
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
208-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
209-
210243
scalar_t operator()(scalar_t value)
211244
{
212-
// sync up each subgroup invocation so it runs in lockstep
213-
spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
214-
215-
binop_t op;
216-
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
217-
[unroll]
218-
for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
219-
value = op(glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);
220-
221-
return value;
245+
return reduction_impl<binop_t, 0, config_t::SizeLog2-1>::__call(value);
222246
}
223247
};
224248

0 commit comments

Comments
 (0)