#include "nbl/builtin/hlsl/cpp_compat.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
#include "nbl/builtin/hlsl/workgroup/arithmetic.hlsl"
#include "nbl/builtin/hlsl/workgroup/scratch_size.hlsl"
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
#include "nbl/builtin/hlsl/enums.hlsl"

namespace nbl
{
namespace hlsl
{
namespace box_blur
{

template<
    typename DataAccessor,
    typename SharedAccessor,
    typename ScanSharedAccessor,
    typename Sampler,
    uint16_t WorkgroupSize,
    class device_capabilities=void> // TODO: define concepts for the Box1D and apply constraints
struct Box1D
{
    // TODO: Generalize later on when Francesco enforces accessor-concepts in `workgroup` and adds a `SharedMemoryAccessor` concept
    struct ScanSharedAccessorWrapper
    {
        void get(const uint16_t ix, NBL_REF_ARG(float32_t) val)
        {
            val = base.template get<float32_t, uint16_t>(ix);
        }

        void set(const uint16_t ix, const float32_t val)
        {
            base.template set<float32_t, uint16_t>(ix, val);
        }

        void workgroupExecutionAndMemoryBarrier()
        {
            base.workgroupExecutionAndMemoryBarrier();
        }

        ScanSharedAccessor base;
    };

    void operator()(
        NBL_REF_ARG(DataAccessor) data,
        NBL_REF_ARG(SharedAccessor) scratch,
        NBL_REF_ARG(ScanSharedAccessor) scanScratch,
        NBL_REF_ARG(Sampler) boxSampler,
        const uint16_t channel)
    {
        const uint16_t end = data.linearSize();
        const uint16_t localInvocationIndex = workgroup::SubgroupContiguousIndex();

        // prefix sum
        // note the dynamically uniform loop condition
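        // The whole blur is driven by an inclusive prefix sum P of the scanline kept in `scratch`:
        // the blurred value at position i is recovered as the difference of two interpolated reads,
        // P(i+radius) - P(i-radius), scaled by 1/(2*radius+1). The sampler below does the interpolation
        // and reconstructs P outside [0,end) from the wrap mode.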
        for (uint16_t baseIx = 0; baseIx < end;)
        {
            const uint16_t ix = localInvocationIndex + baseIx;
            float32_t input = data.template get<float32_t>(channel, ix);
            // dynamically uniform condition
            if (baseIx != 0)
            {
                // take result of previous prefix sum and add it to first element here
                if (localInvocationIndex == 0)
                    input += scratch.template get<float32_t>(baseIx - 1);
            }
            // need to copy-in / copy-out the accessor cause no references in HLSL - yay!
            ScanSharedAccessorWrapper scanScratchWrapper;
            scanScratchWrapper.base = scanScratch;
            const float32_t sum = workgroup::inclusive_scan<plus<float32_t>, WorkgroupSize, device_capabilities>::template __call(input, scanScratchWrapper);
            scanScratch = scanScratchWrapper.base;
            // loop increment
            baseIx += WorkgroupSize;
            // if doing the last prefix sum, we need to barrier to stop aliasing of temporary scratch for `inclusive_scan` and our scanline
            // TODO: might be worth adding a non-aliased mode as NSight says nr 1 hotspot is barrier waiting in this code
            if (end + ScanSharedAccessor::Size > SharedAccessor::Size)
                scratch.workgroupExecutionAndMemoryBarrier();
            // save prefix sum results
            if (ix < end)
                scratch.template set<float32_t>(ix, sum);
            // previous prefix sum must have finished before we ask for results
            scratch.workgroupExecutionAndMemoryBarrier();
        }

        const float32_t last = end - 1;
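        // the sampler returns an unnormalized window integral, scale by 1/(2*radius+1) so the kernel sums to 1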
        const float32_t normalizationFactor = 1.f / (2.f * radius + 1.f);

        for (float32_t ix = localInvocationIndex; ix < end; ix += WorkgroupSize)
        {
            const float32_t result = boxSampler(scratch, ix, radius, borderColor[channel]);
            data.template set<float32_t>(channel, uint16_t(ix), result * normalizationFactor);
        }
    }

    vector<float32_t, DataAccessor::Channels> borderColor;
    float32_t radius;
};
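
// Rough usage sketch, not part of the header: the accessor types (`MyDataAccessor`, `MyScratchAccessor`,
// `MyScanScratchAccessor`) and `WORKGROUP_SIZE` are hypothetical placeholders a user would supply.
//
//    static const uint16_t WORKGROUP_SIZE = 256;
//    typedef box_blur::BoxSampler<MyScratchAccessor, float32_t> sampler_t;
//    typedef box_blur::Box1D<MyDataAccessor, MyScratchAccessor, MyScanScratchAccessor, sampler_t, WORKGROUP_SIZE> blur_t;
//
//    [numthreads(WORKGROUP_SIZE, 1, 1)]
//    void main()
//    {
//        MyDataAccessor data; MyScratchAccessor scratch; MyScanScratchAccessor scanScratch;
//        sampler_t boxSampler;
//        boxSampler.wrapMode = ETC_CLAMP_TO_EDGE;
//        boxSampler.linearSize = data.linearSize();
//
//        blur_t blur;
//        blur.radius = 4.5f;
//        blur.borderColor = (vector<float32_t, MyDataAccessor::Channels>)0.f;
//        for (uint16_t ch = 0; ch < MyDataAccessor::Channels; ch++)
//            blur(data, scratch, scanScratch, boxSampler, ch);
//    }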

template<typename PrefixSumAccessor, typename T>
struct BoxSampler
{
    uint16_t wrapMode;
    uint16_t linearSize;

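    // Returns the (unnormalized) window integral P(ix+radius) - P(ix-radius) of a scanline from its
    // inclusive prefix sum P stored in `prefixSumAccessor`. Fractional endpoints are handled by lerping
    // between the floor and ceil samples, and endpoints falling outside [0,linearSize) get their value
    // of P reconstructed according to `wrapMode` (repeat, clamp, mirror, etc.).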
    T operator()(NBL_REF_ARG(PrefixSumAccessor) prefixSumAccessor, float32_t ix, float32_t radius, float32_t borderColor)
    {
        const float32_t alpha = radius - floor(radius);
        const float32_t lastIdx = linearSize - 1;
        const float32_t rightIdx = float32_t(ix) + radius;
        const float32_t leftIdx = float32_t(ix) - radius;
        const int32_t rightFlIdx = (int32_t)floor(rightIdx);
        const int32_t rightClIdx = (int32_t)ceil(rightIdx);
        const int32_t leftFlIdx = (int32_t)floor(leftIdx);
        const int32_t leftClIdx = (int32_t)ceil(leftIdx);

        T result = 0;
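        // add the right end of the window: P(rightIdx), read directly or rebuilt from the wrap mode
        // when the index runs past the last texel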
        if (rightFlIdx < linearSize)
        {
            result += lerp(prefixSumAccessor.template get<T, uint32_t>(rightFlIdx), prefixSumAccessor.template get<T, uint32_t>(rightClIdx), alpha);
        }
        else
        {
            switch (wrapMode) {
            case ETC_REPEAT:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                const T floored = prefixSumAccessor.template get<T, uint32_t>(rightFlIdx % linearSize) + ceil(float32_t(rightFlIdx % lastIdx) / linearSize) * last;
                const T ceiled = prefixSumAccessor.template get<T, uint32_t>(rightClIdx % linearSize) + ceil(float32_t(rightClIdx % lastIdx) / linearSize) * last;
                result += lerp(floored, ceiled, alpha);
                break;
            }
            case ETC_CLAMP_TO_BORDER:
            {
                result += prefixSumAccessor.template get<T, uint32_t>(lastIdx) + (rightIdx - lastIdx) * borderColor;
                break;
            }
            case ETC_CLAMP_TO_EDGE:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1);
                result += (rightIdx - lastIdx) * (last - lastMinusOne) + last;
                break;
            }
            case ETC_MIRROR:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                T floored, ceiled;
                int32_t d = rightFlIdx - lastIdx;

                if (d % (2 * linearSize) == linearSize)
                    floored = ((d + linearSize) / linearSize) * last;
                else
                {
                    const uint32_t period = uint32_t(ceil(float32_t(d) / linearSize));
                    if ((period & 0x1u) == 1)
                        floored = period * last + last - prefixSumAccessor.template get<T, uint32_t>(lastIdx - uint32_t(d % linearSize));
                    else
                        floored = period * last + prefixSumAccessor.template get<T, uint32_t>((d - 1) % linearSize);
                }

                d = rightClIdx - lastIdx;
                if (d % (2 * linearSize) == linearSize)
                    ceiled = ((d + linearSize) / linearSize) * last;
                else
                {
                    const uint32_t period = uint32_t(ceil(float32_t(d) / linearSize));
                    if ((period & 0x1u) == 1)
                        ceiled = period * last + last - prefixSumAccessor.template get<T, uint32_t>(lastIdx - uint32_t(d % linearSize));
                    else
                        ceiled = period * last + prefixSumAccessor.template get<T, uint32_t>((d - 1) % linearSize);
                }

                result += lerp(floored, ceiled, alpha);
                break;
            }
            case ETC_MIRROR_CLAMP_TO_EDGE:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                const T first = prefixSumAccessor.template get<T, uint32_t>(0);
                const T firstPlusOne = prefixSumAccessor.template get<T, uint32_t>(1);
                result += (rightIdx - lastIdx) * (firstPlusOne - first) + last;
                break;
            }
            }
        }

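        // subtract the left end of the window: P(leftIdx), again rebuilt from the wrap mode when the
        // index falls before the first texel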
        if (leftFlIdx >= 0)
        {
            result -= lerp(prefixSumAccessor.template get<T, uint32_t>(leftFlIdx), prefixSumAccessor.template get<T, uint32_t>(leftClIdx), alpha);
        }
        else
        {
            switch (wrapMode) {
            case ETC_REPEAT:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                const T floored = prefixSumAccessor.template get<T, uint32_t>(abs(leftFlIdx) % linearSize) + ceil(T(leftFlIdx) / linearSize) * last;
                const T ceiled = prefixSumAccessor.template get<T, uint32_t>(abs(leftClIdx) % linearSize) + ceil(float32_t(leftClIdx) / linearSize) * last;
                result -= lerp(floored, ceiled, alpha);
                break;
            }
            case ETC_CLAMP_TO_BORDER:
            {
                result -= prefixSumAccessor.template get<T, uint32_t>(0) + leftIdx * borderColor;
                break;
            }
            case ETC_CLAMP_TO_EDGE:
            {
                result -= leftIdx * prefixSumAccessor.template get<T, uint32_t>(0);
                break;
            }
            case ETC_MIRROR:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                T floored, ceiled;

                if (abs(leftFlIdx + 1) % (2 * linearSize) == 0)
                    floored = -(abs(leftFlIdx + 1) / linearSize) * last;
                else
                {
                    const uint32_t period = uint32_t(ceil(float32_t(abs(leftFlIdx + 1)) / linearSize));
                    if ((period & 0x1u) == 1)
                        floored = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftFlIdx + 1) - 1) % linearSize);
                    else
                        floored = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftFlIdx + 1) % linearSize - 1));
                }

                if (leftClIdx == 0) // Special case, wouldn't be possible for `floored` above
                    ceiled = 0;
                else if (abs(leftClIdx + 1) % (2 * linearSize) == 0)
                    ceiled = -(abs(leftClIdx + 1) / linearSize) * last;
                else
                {
                    const uint32_t period = uint32_t(ceil(float32_t(abs(leftClIdx + 1)) / linearSize));
                    if ((period & 0x1u) == 1)
                        ceiled = -(period - 1) * last - prefixSumAccessor.template get<T, uint32_t>((abs(leftClIdx + 1) - 1) % linearSize);
                    else
                        ceiled = -(period - 1) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftClIdx + 1) % linearSize - 1));
                }

                result -= lerp(floored, ceiled, alpha);
                break;
            }
            case ETC_MIRROR_CLAMP_TO_EDGE:
            {
                const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
                const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1);
                result -= leftIdx * (last - lastMinusOne);
                break;
            }
            }
        }

        return result;
    }
};
259
+
260
+ }
261
+ }
262
+ }