Skip to content

Commit fc1bc51

Browse files
committed
removed redundant stuff, make config more readable
1 parent ce77b46 commit fc1bc51

File tree

3 files changed

+47
-40
lines changed

3 files changed

+47
-40
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,22 @@ struct ArithmeticConfiguration
5555
static_assert(VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
5656

5757
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
58-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
59-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
60-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
58+
using ItemsPerInvocation = typename items_per_invoc_t::ItemsPerInvocation;
59+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = tuple_element<0,ItemsPerInvocation>::type::value;
60+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = tuple_element<1,ItemsPerInvocation>::type::value;
61+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = tuple_element<2,ItemsPerInvocation>::type::value;
6162
static_assert(ItemsPerInvocation_2<=4, "4 level scan would have been needed with this config!");
6263

6364
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3,uint16_t,
6465
mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
6566
SubgroupSize*ItemsPerInvocation_1>::value;
6667
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value;
67-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = LevelInputCount_1 / ItemsPerInvocation_1;
68+
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualInvocationsAtLevel1 = LevelInputCount_1 / ItemsPerInvocation_1;
69+
70+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3,uint16_t,SubgroupSize-1,0>::value;
71+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_1 = conditional_value<LevelCount==3,uint16_t,VirtualInvocationsAtLevel1+__padding,SubgroupSize>::value;
72+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize,0>::value;
73+
using ChannelStride = tuple<integral_constant<uint16_t,__channelStride_1>,integral_constant<uint16_t,__channelStride_2> >;
6874

6975
// user specified the shared mem size of Scalars
7076
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
@@ -74,7 +80,6 @@ struct ArithmeticConfiguration
7480
0
7581
>::value + LevelInputCount_1
7682
>::value;
77-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3,uint16_t,SubgroupSize-1,0>::value;
7883

7984
static bool electLast()
8085
{
@@ -94,16 +99,21 @@ struct ArithmeticConfiguration
9499
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
95100
static uint16_t sharedStoreIndex(const uint16_t virtualSubgroupID)
96101
{
97-
uint16_t nextLevelInvocationCount;
98-
if (level == LevelCount-1)
99-
nextLevelInvocationCount = SubgroupSize;
100-
else
101-
nextLevelInvocationCount = __SubgroupsPerVirtualWorkgroup;
102+
const uint16_t ItemsPerNextInvocation = tuple_element<level,ItemsPerInvocation>::type::value;
103+
const uint16_t outChannel = virtualSubgroupID & (ItemsPerNextInvocation-uint16_t(1u));
104+
const uint16_t outInvocation = virtualSubgroupID/ItemsPerNextInvocation;
105+
const uint16_t localOffset = outChannel * tuple_element<level,ChannelStride>::type::value + outInvocation;
102106

103107
if (level==2)
104-
return LevelInputCount_1 + ((SubgroupSize-uint16_t(1u))*ItemsPerInvocation_1) + (virtualSubgroupID & (ItemsPerInvocation_2-uint16_t(1u))) * nextLevelInvocationCount + (virtualSubgroupID/ItemsPerInvocation_2);
108+
{
109+
const uint16_t baseOffset = LevelInputCount_1 + (SubgroupSize-uint16_t(1u)) * ItemsPerNextInvocation;
110+
return baseOffset + localOffset;
111+
}
105112
else
106-
return (virtualSubgroupID & (ItemsPerInvocation_1-uint16_t(1u))) * (nextLevelInvocationCount+__padding) + (virtualSubgroupID/ItemsPerInvocation_1) + virtualSubgroupID/(SubgroupSize*ItemsPerInvocation_1);
113+
{
114+
const uint16_t paddingOffset = virtualSubgroupID/(SubgroupSize*ItemsPerInvocation_1);
115+
return localOffset + paddingOffset;
116+
}
107117
}
108118

109119
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
@@ -117,16 +127,16 @@ struct ArithmeticConfiguration
117127
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
118128
static uint16_t sharedLoadIndex(const uint16_t invocationIndex, const uint16_t component)
119129
{
120-
uint16_t levelInvocationCount;
121-
if (level == LevelCount-1)
122-
levelInvocationCount = SubgroupSize;
123-
else
124-
levelInvocationCount = __SubgroupsPerVirtualWorkgroup;
130+
const uint16_t localOffset = component * tuple_element<level,ChannelStride>::type::value + invocationIndex;
131+
const uint16_t paddingOffset = invocationIndex/SubgroupSize;
125132

126133
if (level==2)
127-
return LevelInputCount_1 + ((SubgroupSize-uint16_t(1u))*ItemsPerInvocation_1) + component * levelInvocationCount + invocationIndex + invocationIndex/SubgroupSize;
134+
{
135+
const uint16_t baseOffset = LevelInputCount_1 + (SubgroupSize-uint16_t(1u)) * ItemsPerInvocation_1;
136+
return baseOffset + localOffset + paddingOffset;
137+
}
128138
else
129-
return component * (levelInvocationCount+__padding) + invocationIndex + invocationIndex/SubgroupSize;
139+
return localOffset + paddingOffset;
130140
}
131141
};
132142

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ struct reduce_level0
8484
using scalar_t = typename BinOp::type_t;
8585
using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
8686

87-
template<class DataAccessor, class ScratchAccessor, class Params>
87+
template<class DataAccessor, class ScratchAccessor>
8888
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
8989
{
90+
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
91+
using params_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
92+
9093
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
9194
// level 0 scan
92-
subgroup2::reduction<Params> reduction0;
95+
subgroup2::reduction<params_t> reduction0;
9396
[unroll]
9497
for (uint16_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
9598
{
@@ -112,11 +115,14 @@ struct scan_level0
112115
using scalar_t = typename BinOp::type_t;
113116
using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
114117

115-
template<class DataAccessor, class ScratchAccessor, class Params>
118+
template<class DataAccessor, class ScratchAccessor>
116119
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
117120
{
121+
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
122+
using params_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
123+
118124
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
119-
subgroup2::inclusive_scan<Params> inclusiveScan0;
125+
subgroup2::inclusive_scan<params_t> inclusiveScan0;
120126
// level 0 scan
121127
[unroll]
122128
for (uint16_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
@@ -147,11 +153,10 @@ struct reduce<Config, BinOp, 2, device_capabilities>
147153
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
148154
{
149155
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
150-
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
151156
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
152157
BinOp binop;
153158

154-
reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
159+
reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
155160

156161
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
157162
// level 1 scan
@@ -186,11 +191,10 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
186191
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
187192
{
188193
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
189-
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
190194
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
191195
BinOp binop;
192196

193-
scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
197+
scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
194198

195199
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
196200
// level 1 scan
@@ -216,11 +220,9 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
216220
dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
217221

218222
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()-1u), idx);
219-
scalar_t left;
223+
scalar_t left = BinOp::identity;
220224
if (idx != 0 || glsl::gl_SubgroupID() != 0)
221225
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
222-
else
223-
left = BinOp::identity;
224226
if (Exclusive)
225227
{
226228
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
@@ -253,12 +255,11 @@ struct reduce<Config, BinOp, 3, device_capabilities>
253255
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
254256
{
255257
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
256-
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
257258
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
258259
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
259260
BinOp binop;
260261

261-
reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
262+
reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
262263

263264
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
264265
// level 1 scan
@@ -310,12 +311,11 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
310311
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
311312
{
312313
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
313-
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
314314
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
315315
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
316316
BinOp binop;
317317

318-
scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
318+
scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor>(dataAccessor, scratchAccessor);
319319

320320
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
321321
// level 1 scan
@@ -357,12 +357,10 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
357357
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
358358
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
359359

360-
scalar_t lv2_scan;
360+
scalar_t lv2_scan = BinOp::identity;
361361
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()-1u));
362362
if (glsl::gl_SubgroupID() != 0)
363363
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, lv2_scan);
364-
else
365-
lv2_scan = BinOp::identity;
366364

367365
[unroll]
368366
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
@@ -378,11 +376,9 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
378376
dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
379377

380378
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()-1u), idx);
381-
scalar_t left;
379+
scalar_t left = BinOp::identity;
382380
if (idx != 0 || glsl::gl_SubgroupID() != 0)
383381
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
384-
else
385-
left = BinOp::identity;
386382
if (Exclusive)
387383
{
388384
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));

src/nbl/builtin/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/anisotropi
369369
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/loadable_image.hlsl")
370370
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/mip_mapped.hlsl")
371371
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/storable_image.hlsl")
372+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/generic_shared_data.hlsl")
372373
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/fft.hlsl")
373374
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/workgroup_arithmetic.hlsl")
374375
#tgmath

0 commit comments

Comments
 (0)