Skip to content

Commit 19973ac

Browse files
committed
box_blur: move into here from example code
Signed-off-by: Ali Cheraghi <alichraghi@proton.me>
1 parent 83db408 commit 19973ac

File tree

2 files changed

+262
-2
lines changed

2 files changed

+262
-2
lines changed

include/nbl/builtin/hlsl/enums.hlsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ enum TextureClamp : uint16_t
4444
ETC_MIRROR,
4545
//! Texture is mirrored once and then clamped to edge
4646
ETC_MIRROR_CLAMP_TO_EDGE,
47-
//! Texture is mirrored once and then clamped to border
48-
ETC_MIRROR_CLAMP_TO_BORDER,
4947

5048
ETC_COUNT
5149
};
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
2+
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
3+
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
4+
#include "nbl/builtin/hlsl/workgroup/arithmetic.hlsl"
5+
#include "nbl/builtin/hlsl/workgroup/scratch_size.hlsl"
6+
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
7+
#include "nbl/builtin/hlsl/enums.hlsl"
8+
9+
namespace nbl
10+
{
11+
namespace hlsl
12+
{
13+
namespace box_blur
14+
{
15+
16+
template<
17+
typename DataAccessor,
18+
typename SharedAccessor,
19+
typename ScanSharedAccessor,
20+
typename Sampler,
21+
uint16_t WorkgroupSize,
22+
class device_capabilities=void> // TODO: define concepts for the Box1D and apply constraints
23+
struct Box1D
24+
{
25+
// TODO: Generalize later on when Francesco enforces accessor-concepts in `workgroup` and adds a `SharedMemoryAccessor` concept
26+
struct ScanSharedAccessorWrapper
27+
{
28+
void get(const uint16_t ix, NBL_REF_ARG(float32_t) val)
29+
{
30+
val = base.template get<float32_t, uint16_t>(ix);
31+
}
32+
33+
void set(const uint16_t ix, const float32_t val)
34+
{
35+
base.template set<float32_t, uint16_t>(ix, val);
36+
}
37+
38+
void workgroupExecutionAndMemoryBarrier()
39+
{
40+
base.workgroupExecutionAndMemoryBarrier();
41+
}
42+
43+
ScanSharedAccessor base;
44+
};
45+
46+
void operator()(
47+
NBL_REF_ARG(DataAccessor) data,
48+
NBL_REF_ARG(SharedAccessor) scratch,
49+
NBL_REF_ARG(ScanSharedAccessor) scanScratch,
50+
NBL_REF_ARG(Sampler) boxSampler,
51+
const uint16_t channel)
52+
{
53+
const uint16_t end = data.linearSize();
54+
const uint16_t localInvocationIndex = workgroup::SubgroupContiguousIndex();
55+
56+
// prefix sum
57+
// note the dynamically uniform loop condition
58+
for (uint16_t baseIx = 0; baseIx < end;)
59+
{
60+
const uint16_t ix = localInvocationIndex + baseIx;
61+
float32_t input = data.template get<float32_t>(channel, ix);
62+
// dynamically uniform condition
63+
if (baseIx != 0)
64+
{
65+
// take result of previous prefix sum and add it to first element here
66+
if (localInvocationIndex == 0)
67+
input += scratch.template get<float32_t>(baseIx - 1);
68+
}
69+
// need to copy-in / copy-out the accessor cause no references in HLSL - yay!
70+
ScanSharedAccessorWrapper scanScratchWrapper;
71+
scanScratchWrapper.base = scanScratch;
72+
const float32_t sum = workgroup::inclusive_scan<plus<float32_t>, WorkgroupSize, device_capabilities>::template __call(input, scanScratchWrapper);
73+
scanScratch = scanScratchWrapper.base;
74+
// loop increment
75+
baseIx += WorkgroupSize;
76+
// if doing the last prefix sum, we need to barrier to stop aliasing of temporary scratch for `inclusive_scan` and our scanline
77+
// TODO: might be worth adding a non-aliased mode as NSight says nr 1 hotspot is barrier waiting in this code
78+
if (end + ScanSharedAccessor::Size > SharedAccessor::Size)
79+
scratch.workgroupExecutionAndMemoryBarrier();
80+
// save prefix sum results
81+
if (ix < end)
82+
scratch.template set<float32_t>(ix, sum);
83+
// previous prefix sum must have finished before we ask for results
84+
scratch.workgroupExecutionAndMemoryBarrier();
85+
}
86+
87+
const float32_t last = end - 1;
88+
const float32_t normalizationFactor = 1.f / (2.f * radius + 1.f);
89+
90+
for (float32_t ix = localInvocationIndex; ix < end; ix += WorkgroupSize)
91+
{
92+
const float32_t result = boxSampler(scratch, ix, radius, borderColor[channel]);
93+
data.template set<float32_t>(channel, uint16_t(ix), result * normalizationFactor);
94+
}
95+
}
96+
97+
vector<float32_t, DataAccessor::Channels> borderColor;
98+
float32_t radius;
99+
};
100+
101+
template<typename PrefixSumAccessor, typename T>
102+
struct BoxSampler
103+
{
104+
uint16_t wrapMode;
105+
uint16_t linearSize;
106+
107+
T operator()(NBL_REF_ARG(PrefixSumAccessor) prefixSumAccessor, float32_t ix, float32_t radius, float32_t borderColor)
108+
{
109+
const float32_t alpha = radius - floor(radius);
110+
const float32_t lastIdx = linearSize - 1;
111+
const float32_t rightIdx = float32_t(ix) + radius;
112+
const float32_t leftIdx = float32_t(ix) - radius;
113+
const int32_t rightFlIdx = (int32_t)floor(rightIdx);
114+
const int32_t rightClIdx = (int32_t)ceil(rightIdx);
115+
const int32_t leftFlIdx = (int32_t)floor(leftIdx);
116+
const int32_t leftClIdx = (int32_t)ceil(leftIdx);
117+
118+
T result = 0;
119+
if (rightFlIdx < linearSize)
120+
{
121+
result += lerp(prefixSumAccessor.template get<T, uint32_t>(rightFlIdx), prefixSumAccessor.template get<T, uint32_t>(rightClIdx), alpha);
122+
}
123+
else
124+
{
125+
switch (wrapMode) {
126+
case ETC_REPEAT:
127+
{
128+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
129+
const T floored = prefixSumAccessor.template get<T, uint32_t>(rightFlIdx % linearSize) + ceil(float32_t(rightFlIdx % lastIdx) / linearSize) * last;
130+
const T ceiled = prefixSumAccessor.template get<T, uint32_t>(rightClIdx % linearSize) + ceil(float32_t(rightClIdx % lastIdx) / linearSize) * last;
131+
result += lerp(floored, ceiled, alpha);
132+
break;
133+
}
134+
case ETC_CLAMP_TO_BORDER:
135+
{
136+
result += prefixSumAccessor.template get<T, uint32_t>(lastIdx) + (rightIdx - lastIdx) * borderColor;
137+
break;
138+
}
139+
case ETC_CLAMP_TO_EDGE:
140+
{
141+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
142+
const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1);
143+
result += (rightIdx - lastIdx) * (last - lastMinusOne) + last;
144+
break;
145+
}
146+
case ETC_MIRROR:
147+
{
148+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
149+
T floored, ceiled;
150+
int32_t d = rightFlIdx - lastIdx;
151+
152+
if (d % (2 * linearSize) == linearSize)
153+
floored = ((d + linearSize) / linearSize) * last;
154+
else
155+
{
156+
const uint32_t period = uint32_t(ceil(float32_t(d) / linearSize));
157+
if ((period & 0x1u) == 1)
158+
floored = period * last + last - prefixSumAccessor.template get<T, uint32_t>(lastIdx - uint32_t(d % linearSize));
159+
else
160+
floored = period * last + prefixSumAccessor.template get<T, uint32_t>((d - 1) % linearSize);
161+
}
162+
163+
d = rightClIdx - lastIdx;
164+
if (d % (2 * linearSize) == linearSize)
165+
ceiled = ((d + linearSize) / linearSize) * last;
166+
else
167+
{
168+
const uint32_t period = uint32_t(ceil(float32_t(d) / linearSize));
169+
if ((period & 0x1u) == 1)
170+
ceiled = period * last + last - prefixSumAccessor.template get<T, uint32_t>(lastIdx - uint32_t(d % linearSize));
171+
else
172+
ceiled = period * last + prefixSumAccessor.template get<T, uint32_t>((d - 1) % linearSize);
173+
}
174+
175+
result += lerp(floored, ceiled, alpha);
176+
break;
177+
}
178+
case ETC_MIRROR_CLAMP_TO_EDGE:
179+
{
180+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
181+
const T first = prefixSumAccessor.template get<T, uint32_t>(0);
182+
const T firstPlusOne = prefixSumAccessor.template get<T, uint32_t>(1);
183+
result += (rightIdx - lastIdx) * (firstPlusOne - first) + last;
184+
break;
185+
}
186+
}
187+
}
188+
189+
if (leftFlIdx >= 0)
190+
{
191+
result -= lerp(prefixSumAccessor.template get<T, uint32_t>(leftFlIdx), prefixSumAccessor.template get<T, uint32_t>(leftClIdx), alpha);
192+
}
193+
else
194+
{
195+
switch (wrapMode) {
196+
case ETC_REPEAT:
197+
{
198+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
199+
const T floored = prefixSumAccessor.template get<T, uint32_t>(abs(leftFlIdx) % linearSize) + ceil(T(leftFlIdx) / linearSize) * last;
200+
const T ceiled = prefixSumAccessor.template get<T, uint32_t>(abs(leftClIdx) % linearSize) + ceil(float32_t(leftClIdx) / linearSize) * last;
201+
result -= lerp(floored, ceiled, alpha);
202+
break;
203+
}
204+
case ETC_CLAMP_TO_BORDER:
205+
{
206+
result -= prefixSumAccessor.template get<T, uint32_t>(0) + leftIdx * borderColor;
207+
break;
208+
}
209+
case ETC_CLAMP_TO_EDGE:
210+
{
211+
result -= leftIdx * prefixSumAccessor.template get<T, uint32_t>(0);
212+
break;
213+
}
214+
case ETC_MIRROR:
215+
{
216+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
217+
T floored, ceiled;
218+
219+
if (abs(leftFlIdx + 1) % (2 * linearSize) == 0)
220+
floored = -(abs(leftFlIdx + 1) / linearSize) * last;
221+
else
222+
{
223+
const uint32_t period = uint32_t(ceil(float32_t(abs(leftFlIdx + 1)) / linearSize));
224+
if ((period & 0x1u) == 1)
225+
floored = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftFlIdx + 1) - 1) % linearSize);
226+
else
227+
floored = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftFlIdx + 1) % linearSize - 1));
228+
}
229+
230+
if (leftClIdx == 0) // Special case, wouldn't be possible for `floored` above
231+
ceiled = 0;
232+
else if (abs(leftClIdx + 1) % (2 * linearSize) == 0)
233+
ceiled = -(abs(leftClIdx + 1) / linearSize) * last;
234+
else
235+
{
236+
const uint32_t period = uint32_t(ceil(float32_t(abs(leftClIdx + 1)) / linearSize));
237+
if ((period & 0x1u) == 1)
238+
ceiled = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftClIdx + 1) - 1) % linearSize);
239+
else
240+
ceiled = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftClIdx + 1) % linearSize - 1));
241+
}
242+
243+
result -= lerp(floored, ceiled, alpha);
244+
break;
245+
}
246+
case ETC_MIRROR_CLAMP_TO_EDGE:
247+
{
248+
const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
249+
const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1);
250+
result -= leftIdx * (last - lastMinusOne);
251+
break;
252+
}
253+
}
254+
}
255+
256+
return result;
257+
}
258+
};
259+
260+
}
261+
}
262+
}

0 commit comments

Comments
 (0)