Skip to content

Commit 10b7f50

Browse files
committed
Fix some bugs and improve readability
1 parent fc1bc51 commit 10b7f50

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ struct ArithmeticConfiguration
6868
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualInvocationsAtLevel1 = LevelInputCount_1 / ItemsPerInvocation_1;
6969

7070
NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3,uint16_t,SubgroupSize-1,0>::value;
71-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_1 = conditional_value<LevelCount==3,uint16_t,VirtualInvocationsAtLevel1+__padding,SubgroupSize>::value;
71+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_1 = conditional_value<LevelCount==3,uint16_t,VirtualInvocationsAtLevel1,SubgroupSize>::value + __padding;
7272
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize,0>::value;
73-
using ChannelStride = tuple<integral_constant<uint16_t,__channelStride_1>,integral_constant<uint16_t,__channelStride_2> >;
73+
using ChannelStride = tuple<integral_constant<uint16_t,__padding>,integral_constant<uint16_t,__channelStride_1>,integral_constant<uint16_t,__channelStride_2> >; // we don't use stride 0
7474

7575
// user-specified shared memory size, measured in Scalars
7676
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
@@ -101,17 +101,17 @@ struct ArithmeticConfiguration
101101
{
102102
const uint16_t ItemsPerNextInvocation = tuple_element<level,ItemsPerInvocation>::type::value;
103103
const uint16_t outChannel = virtualSubgroupID & (ItemsPerNextInvocation-uint16_t(1u));
104-
const uint16_t outInvocation = virtualSubgroupID/ItemsPerNextInvocation;
104+
const uint16_t outInvocation = virtualSubgroupID / ItemsPerNextInvocation;
105105
const uint16_t localOffset = outChannel * tuple_element<level,ChannelStride>::type::value + outInvocation;
106106

107107
if (level==2)
108108
{
109-
const uint16_t baseOffset = LevelInputCount_1 + (SubgroupSize-uint16_t(1u)) * ItemsPerNextInvocation;
109+
const uint16_t baseOffset = LevelInputCount_1 + (SubgroupSize - uint16_t(1u)) * ItemsPerInvocation_1;
110110
return baseOffset + localOffset;
111111
}
112112
else
113113
{
114-
const uint16_t paddingOffset = virtualSubgroupID/(SubgroupSize*ItemsPerInvocation_1);
114+
const uint16_t paddingOffset = virtualSubgroupID / (SubgroupSize * ItemsPerInvocation_1);
115115
return localOffset + paddingOffset;
116116
}
117117
}
@@ -128,11 +128,11 @@ struct ArithmeticConfiguration
128128
static uint16_t sharedLoadIndex(const uint16_t invocationIndex, const uint16_t component)
129129
{
130130
const uint16_t localOffset = component * tuple_element<level,ChannelStride>::type::value + invocationIndex;
131-
const uint16_t paddingOffset = invocationIndex/SubgroupSize;
131+
const uint16_t paddingOffset = invocationIndex / SubgroupSize;
132132

133133
if (level==2)
134134
{
135-
const uint16_t baseOffset = LevelInputCount_1 + (SubgroupSize-uint16_t(1u)) * ItemsPerInvocation_1;
135+
const uint16_t baseOffset = LevelInputCount_1 + (SubgroupSize - uint16_t(1u)) * ItemsPerInvocation_1;
136136
return baseOffset + localOffset + paddingOffset;
137137
}
138138
else

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,11 +337,15 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
337337
subgroup2::inclusive_scan<params_lv2_t> inclusiveScan2;
338338
if (glsl::gl_SubgroupID() == 0)
339339
{
340-
const uint16_t one = uint16_t(1u);
340+
const uint16_t lastChannel = Config::ItemsPerInvocation_1 - uint16_t(1u);
341341
vector_lv2_t lv2_val;
342342
[unroll]
343343
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
344-
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>((invocationIndex*Config::ItemsPerInvocation_2+i+one)*Config::SubgroupSize-one, Config::ItemsPerInvocation_1-one),lv2_val[i]);
344+
{
345+
const uint16_t inputSubgroupID = invocationIndex * Config::ItemsPerInvocation_2 + i;
346+
const uint16_t inputSubgroupLastInvocation = inputSubgroupID * Config::SubgroupSize + (Config::SubgroupSize - uint16_t(1u));
347+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(inputSubgroupLastInvocation, lastChannel),lv2_val[i]);
348+
}
345349
lv2_val = inclusiveScan2(lv2_val);
346350
[unroll]
347351
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)

0 commit comments

Comments
 (0)