Skip to content

Commit b65dc7f

Browse files
committed
box_sampler: rework ETC_REPEAT
Signed-off-by: Ali Cheraghi <alichraghi@proton.me>
1 parent d2adf17 commit b65dc7f

File tree

2 files changed

+35
-52
lines changed

2 files changed

+35
-52
lines changed

include/nbl/builtin/hlsl/prefix_sum_blur/blur.hlsl

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace hlsl
1515
namespace prefix_sum_blur
1616
{
1717

18-
// Prefix-Sum Blur using SAT (Summed Area Table) technique
18+
// Prefix-Sum Blur using SAT (Summed Area Table) technique.
19+
// `scanScratch` and `_sampler.prefixSumAccessor` must not alias.
1920
template<
2021
typename DataAccessor,
2122
typename ScanSharedAccessor,
@@ -24,27 +25,6 @@ template<
2425
class device_capabilities=void> // TODO: define concepts for the Box1D and apply constraints
2526
struct Blur1D
2627
{
27-
// TODO: Generalize later on when Francesco enforces accessor-concepts in `workgroup` and adds a `SharedMemoryAccessor` concept
28-
struct ScanSharedAccessorWrapper
29-
{
30-
void get(const uint16_t ix, NBL_REF_ARG(float32_t) val)
31-
{
32-
val = base.template get<float32_t, uint16_t>(ix);
33-
}
34-
35-
void set(const uint16_t ix, const float32_t val)
36-
{
37-
base.template set<float32_t, uint16_t>(ix, val);
38-
}
39-
40-
void workgroupExecutionAndMemoryBarrier()
41-
{
42-
base.workgroupExecutionAndMemoryBarrier();
43-
}
44-
45-
ScanSharedAccessor base;
46-
};
47-
4828
void operator()(
4929
NBL_REF_ARG(DataAccessor) data,
5030
NBL_REF_ARG(ScanSharedAccessor) scanScratch,
@@ -67,17 +47,9 @@ struct Blur1D
6747
if (localInvocationIndex == 0)
6848
input += _sampler.prefixSumAccessor.template get<float32_t>(baseIx - 1);
6949
}
70-
// need to copy-in / copy-out the accessor cause no references in HLSL - yay!
71-
ScanSharedAccessorWrapper scanScratchWrapper;
72-
scanScratchWrapper.base = scanScratch;
73-
const float32_t sum = workgroup::inclusive_scan<plus<float32_t>, WorkgroupSize, device_capabilities>::template __call(input, scanScratchWrapper);
74-
scanScratch = scanScratchWrapper.base;
50+
const float32_t sum = workgroup::inclusive_scan<plus<float32_t>, WorkgroupSize, device_capabilities>::template __call(input, scanScratch);
7551
// loop increment
7652
baseIx += WorkgroupSize;
77-
// if doing the last prefix sum, we need to barrier to stop aliasing of temporary scratch for `inclusive_scan` and our scanline
78-
// TODO: might be worth adding a non-aliased mode as NSight says nr 1 hotspot is barrier waiting in this code
79-
if (end + ScanSharedAccessor::Size > Sampler::prefix_sum_accessor_t::Size)
80-
_sampler.prefixSumAccessor.workgroupExecutionAndMemoryBarrier();
8153
// save prefix sum results
8254
if (ix < end)
8355
_sampler.prefixSumAccessor.template set<float32_t>(ix, sum);

include/nbl/builtin/hlsl/prefix_sum_blur/box_sampler.hlsl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace hlsl
1212
namespace prefix_sum_blur
1313
{
1414

15+
// Requires an *inclusive* prefix sum
1516
template<typename PrefixSumAccessor, typename T>
1617
struct BoxSampler
1718
{
@@ -20,7 +21,6 @@ struct BoxSampler
2021
PrefixSumAccessor prefixSumAccessor;
2122
uint16_t wrapMode;
2223
uint16_t linearSize;
23-
T normalizationFactor;
2424

2525
T operator()(float32_t ix, float32_t radius, float32_t borderColor)
2626
{
@@ -33,7 +33,8 @@ struct BoxSampler
3333
const int32_t leftFlIdx = (int32_t)floor(leftIdx);
3434
const int32_t leftClIdx = (int32_t)ceil(leftIdx);
3535

36-
assert(linearSize > 1);
36+
assert(linearSize > 1 && radius >= 0);
37+
assert(borderColor >= 0 && borderColor <= 1);
3738

3839
T result = 0;
3940
if (rightClIdx < linearSize)
@@ -45,10 +46,15 @@ struct BoxSampler
4546
switch (wrapMode) {
4647
case ETC_REPEAT:
4748
{
49+
const uint32_t flooredMod = rightFlIdx % linearSize;
50+
const uint32_t ceiledMod = rightClIdx % linearSize;
4851
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
49-
const T floored = prefixSumAccessor.template get<T, uint32_t>(rightFlIdx % linearSize) + last;
50-
const T ceiled = prefixSumAccessor.template get<T, uint32_t>(rightClIdx % linearSize) + last;
51-
result += lerp(floored, ceiled, alpha);
52+
const T periodicOffset = (T(rightFlIdx) / linearSize) * last;
53+
const T floored = prefixSumAccessor.template get<T, uint32_t>(flooredMod);
54+
T ceiled = prefixSumAccessor.template get<T, uint32_t>(ceiledMod);
55+
if (flooredMod == lastIdx && ceiledMod == 0)
56+
ceiled += last;
57+
result += lerp(floored, ceiled, alpha) + periodicOffset;
5258
break;
5359
}
5460
case ETC_CLAMP_TO_BORDER:
@@ -114,10 +120,15 @@ struct BoxSampler
114120
switch (wrapMode) {
115121
case ETC_REPEAT:
116122
{
123+
const uint32_t flooredMod = (linearSize + leftFlIdx) % linearSize;
124+
const uint32_t ceiledMod = (linearSize + leftClIdx) % linearSize;
117125
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
118-
const T floored = prefixSumAccessor.template get<T, uint32_t>((lastIdx + leftFlIdx) % linearSize) + floor(T(leftFlIdx) / linearSize) * last;
119-
const T ceiled = prefixSumAccessor.template get<T, uint32_t>((lastIdx + leftClIdx) % linearSize) + floor(T(leftClIdx) / linearSize) * last;
120-
result -= lerp(floored, ceiled, alpha);
126+
const T periodicOffset = (T(linearSize + leftClIdx) / T(linearSize)) * last;
127+
const T floored = prefixSumAccessor.template get<T, uint32_t>(flooredMod);
128+
T ceiled = prefixSumAccessor.template get<T, uint32_t>(ceiledMod);
129+
if (flooredMod == lastIdx && ceiledMod == 0)
130+
ceiled += last;
131+
result -= lerp(floored, ceiled, alpha) - periodicOffset;
121132
break;
122133
}
123134
case ETC_CLAMP_TO_BORDER:
@@ -127,36 +138,36 @@ struct BoxSampler
127138
}
128139
case ETC_CLAMP_TO_EDGE:
129140
{
130-
result -= (1 - abs(leftIdx)) * prefixSumAccessor.template get<T, uint32_t>(0);
141+
result -= (leftIdx + 1) * prefixSumAccessor.template get<T, uint32_t>(0);
131142
break;
132143
}
133144
case ETC_MIRROR:
134145
{
135146
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
136147
T floored, ceiled;
137148

138-
if (abs(leftFlIdx + 1) % (2 * linearSize) == 0)
139-
floored = -(abs(leftFlIdx + 1) / linearSize) * last;
149+
if (abs(leftFlIdx) % (2 * linearSize) == 0)
150+
floored = -(abs(leftFlIdx) / linearSize) * last;
140151
else
141152
{
142-
const uint32_t period = uint32_t(ceil(float32_t(abs(leftFlIdx + 1)) / linearSize));
153+
const uint32_t period = uint32_t(ceil(float32_t(abs(leftFlIdx)) / linearSize));
143154
if ((period & 0x1u) == 1)
144-
floored = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftFlIdx + 1) - 1) % linearSize);
155+
floored = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftFlIdx) - 1) % linearSize);
145156
else
146-
floored = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftFlIdx + 1) % linearSize - 1));
157+
floored = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>(leftFlIdx % linearSize - 1));
147158
}
148159

149160
if (leftClIdx == 0) // Special case, wouldn't be possible for `floored` above
150161
ceiled = 0;
151-
else if (abs(leftClIdx + 1) % (2 * linearSize) == 0)
152-
ceiled = -(abs(leftClIdx + 1) / linearSize) * last;
162+
else if (abs(leftClIdx) % (2 * linearSize) == 0)
163+
ceiled = -(abs(leftClIdx) / linearSize) * last;
153164
else
154165
{
155-
const uint32_t period = uint32_t(ceil(float32_t(abs(leftClIdx + 1)) / linearSize));
166+
const uint32_t period = uint32_t(ceil(float32_t(abs(leftClIdx)) / linearSize));
156167
if ((period & 0x1u) == 1)
157-
ceiled = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftClIdx + 1) - 1) % linearSize);
168+
ceiled = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftClIdx) - 1) % linearSize);
158169
else
159-
ceiled = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftClIdx + 1) % linearSize - 1));
170+
ceiled = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>(leftClIdx % linearSize - 1));
160171
}
161172

162173
result -= lerp(floored, ceiled, alpha);
@@ -166,13 +177,13 @@ struct BoxSampler
166177
{
167178
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
168179
const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1);
169-
result -= (1 - abs(leftIdx)) * (last - lastMinusOne);
180+
result -= (leftIdx + 1) * (last - lastMinusOne);
170181
break;
171182
}
172183
}
173184
}
174185

175-
return result * normalizationFactor;
186+
return result / (2 * radius + 1);
176187
}
177188
};
178189

0 commit comments

Comments
 (0)