Skip to content

Commit aae1dd8

Browse files
committed
PR review refactor
1 parent f59e68e commit aae1dd8

File tree

1 file changed

+100
-92
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+100
-92
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 100 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -23,102 +23,76 @@ namespace fft
2323
{
2424

2525
// ---------------------------------- Utils -----------------------------------------------
26-
template<typename SharedMemoryAdaptor, typename Scalar>
27-
struct exchangeValues
26+
27+
// No need to expose these
28+
namespace impl
2829
{
29-
static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
30+
template<typename SharedMemoryAdaptor, typename Scalar>
31+
struct exchangeValues
3032
{
31-
const bool topHalf = bool(threadID & stride);
32-
// Pack into float vector because ternary operator does not support structs
33-
vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
34-
shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
35-
if (topHalf)
36-
{
37-
lo.real(exchanged.x);
38-
lo.imag(exchanged.y);
39-
}
40-
else
33+
// Exchanges one of this invocation's two values through `shuffleXor`, forwarding `stride` and the shared memory adaptor to it.
// Invocations for which `threadID & stride` is non-zero send `lo` and overwrite it with the value they get back; all others do the same with `hi`.
// NOTE(review): presumably the partner is the invocation whose ID is `threadID ^ stride` - confirm against `shuffleXor`.
static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
4134
{
42-
hi.real(exchanged.x);
43-
hi.imag(exchanged.y);
35+
const bool topHalf = bool(threadID & stride);
36+
// Pack the outgoing value into a `vector<Scalar, 2>` (real in .x, imaginary in .y) because the ternary operator does not support structs
37+
vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
38+
shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
39+
// Unpack the value left in `exchanged` by the shuffle into the same slot that was sent
if (topHalf)
40+
{
41+
lo.real(exchanged.x);
42+
lo.imag(exchanged.y);
43+
}
44+
else
45+
{
46+
hi.real(exchanged.x);
47+
hi.imag(exchanged.y);
48+
}
4449
}
45-
}
46-
};
47-
48-
// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
49-
template <typename scalar_t, uint32_t WorkgroupSize>
50-
NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
51-
50+
};
5251

53-
template<uint32_t N, uint32_t H>
54-
enable_if_t<H <= N, uint32_t> bitShiftRightHigher(uint32_t i)
55-
{
56-
// Highest H bits are numbered N-1 through N - H
57-
// N - H is then the middle bit
58-
// Lowest bits numbered from 0 through N - H - 1
59-
uint32_t low = i & ((1 << (N - H)) - 1);
60-
uint32_t mid = i & (1 << (N - H));
61-
uint32_t high = i & ~((1 << (N - H + 1)) - 1);
62-
63-
high >>= 1;
64-
mid <<= H - 1;
65-
66-
return mid | high | low;
67-
}
52+
// Treats `i` as an N-bit number and rotates its highest H bits one position to the right, leaving the lowest N - H bits untouched:
// bit N - H wraps around to bit N - 1, and every bit above N - H moves down by one.
// NOTE(review): `highMask` also keeps any bits of `i` at or above position N, so callers are presumably expected to pass `i < 2^N` - confirm.
// NOTE(review): `H == 0` satisfies the constraint below but would make the later `H - 1` shift amount negative - confirm it is never instantiated that way.
template<uint16_t N, uint16_t H>
53+
enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
54+
{
55+
// Highest H bits are numbered N-1 through N - H
56+
// N - H is then the middle bit
57+
// Lowest bits numbered from 0 through N - H - 1
58+
// Selects bits 0 through N - H - 1, which pass through unchanged
NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
59+
// Selects the single bit N - H, the one that wraps around to the top
NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
60+
// Selects every bit above N - H, i.e. the bits that move down by one
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~((1 << (N - H + 1)) - 1);
6861

69-
template<uint32_t N, uint32_t H>
70-
enable_if_t<H <= N, uint32_t> bitShiftLeftHigher(uint32_t i)
71-
{
72-
// Highest H bits are numbered N-1 through N - H
73-
// N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
74-
// Lowest bits numbered from 0 through N - H - 1
75-
uint32_t low = i & ((1 << (N - H)) - 1);
76-
uint32_t mid = i & (~((1 << (N - H)) - 1) | ~(1 << (N - 1)));
77-
uint32_t high = i & (1 << (N - 1));
62+
uint32_t low = i & lowMask;
63+
uint32_t mid = i & midMask;
64+
uint32_t high = i & highMask;
7865

79-
mid <<= 1;
80-
high >>= H - 1;
66+
high >>= 1;
67+
mid <<= H - 1;
8168

82-
return mid | high | low;
83-
}
69+
return mid | high | low;
70+
}
8471

85-
// This function maps the index `idx` in the output array of a Forward FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = output[idx]`
86-
// This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
87-
template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
88-
uint32_t getFrequencyIndex(uint32_t outputIdx)
89-
{
90-
NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t(mpl::log2<ElementsPerInvocation>::value);
91-
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t(mpl::log2<WorkgroupSize>::value);
72+
// Treats `i` as an N-bit number and rotates its highest H bits one position to the left, leaving the lowest N - H bits untouched:
// bit N - 1 wraps around to bit N - H, and bits N - 2 through N - H each move up by one.
// NOTE(review): assumes `i < 2^N` (bits at or above position N are not cleared by `midMask`) and `H >= 1` (the later shift uses `H - 1`) - confirm against callers.
template<uint16_t N, uint16_t H>
73+
enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
74+
{
75+
// Highest H bits are numbered N-1 through N - H
76+
// N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
77+
// Lowest bits numbered from 0 through N - H - 1
78+
// Selects bits 0 through N - H - 1, which pass through unchanged
NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
79+
// The middle bits are whatever is neither a low bit nor bit N - 1, so the two complements must be intersected.
// OR-ing them yields an all-ones mask, which would shift the low bits as well and push bit N - 1 out to bit N.
NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~((1 << (N - H)) - 1) & ~(1 << (N - 1));
80+
// Selects the single bit N - 1, the one that wraps around to position N - H
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
9281

93-
return bitShiftRightHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFT_SIZE_LOG_2));
94-
}
82+
uint32_t low = i & lowMask;
83+
uint32_t mid = i & midMask;
84+
uint32_t high = i & highMask;
9585

96-
// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Forward FFT such that `DFT[freqIdx] = output[idx]`
97-
// It is essentially the inverse of `getFrequencyIndex`
98-
template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
99-
uint32_t getOutputIndex(uint32_t freqIdx)
100-
{
101-
NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t(mpl::log2<ElementsPerInvocation>::value);
102-
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t(mpl::log2<WorkgroupSize>::value);
86+
mid <<= 1;
87+
high >>= H - 1;
10388

104-
return glsl::bitfieldReverse<uint32_t>(bitShiftLeftHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1>(freqIdx)) >> (32 - FFT_SIZE_LOG_2);
105-
}
106-
107-
// Mirrors an index about the Nyquist frequency
108-
template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
109-
uint32_t mirror(uint32_t idx)
110-
{
111-
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE = WorkgroupSize * uint32_t(ElementsPerInvocation);
112-
return (FFT_SIZE - idx) & (FFT_SIZE - 1);
113-
}
89+
return mid | high | low;
90+
}
91+
} //namespace impl
11492

115-
// When packing real FFTs a common operation is to get `DFT[T]` and `DFT[-T]` to unpack the result of a packed real FFT.
116-
// Given an index `idx` into the Nabla-ordered DFT such that `output[idx] = DFT[T]`, this function is such that `output[getNegativeIndex(idx)] = DFT[-T]`
117-
template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
118-
uint32_t getNegativeIndex(uint32_t idx)
119-
{
120-
return getOutputIndex<ElementsPerInvocation, WorkgroupSize>(mirror<ElementsPerInvocation, WorkgroupSize>(getFrequencyIndex<ElementsPerInvocation, WorkgroupSize>(idx)));
121-
}
93+
// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT:
// the number of uint32_t that make up one `complex_t<scalar_t>`, multiplied by `WorkgroupSize`
94+
template <typename scalar_t, uint16_t WorkgroupSize>
95+
NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
12296

12397
// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
12498
template<typename Scalar>
@@ -129,11 +103,45 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
129103
lo = x;
130104
}
131105

106+
// Static helpers that convert between positions in the output array of a Nabla FFT and positions in the DFT,
// for an FFT of `WorkgroupSize * ElementsPerInvocation` elements.
// NOTE(review): the bit tricks below presumably require both template parameters to be powers of two - confirm.
template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
107+
struct FFTIndexingUtils
108+
{
109+
// This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
110+
// This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
// Steps: reverse the lowest `FFTSizeLog2` bits of `outputIdx` (full 32-bit reverse, then shift back down),
// then rotate the highest `FFTSizeLog2 - ElementsPerInvocationLog2 + 1` of those bits one position to the right
111+
static uint32_t getDFTIndex(uint32_t outputIdx)
112+
{
113+
return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFTSizeLog2));
114+
}
115+
116+
// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
117+
// It is essentially the inverse of `getDFTIndex`
// Steps are applied in the opposite order: rotate the same group of high bits one position to the left, then reverse the lowest `FFTSizeLog2` bits
118+
static uint32_t getNablaIndex(uint32_t freqIdx)
119+
{
120+
return glsl::bitfieldReverse<uint32_t>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx)) >> (32 - FFTSizeLog2);
121+
}
122+
123+
// Mirrors an index about the Nyquist frequency in the DFT order
// Computes `FFTSize - idx` and keeps only the bits selected by `FFTSize - 1`, so `idx == 0` maps back to 0
124+
static uint32_t getDFTMirrorIndex(uint32_t idx)
125+
{
126+
return (FFTSize - idx) & (FFTSize - 1);
127+
}
128+
129+
// Given an index `idx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
// Done by converting to DFT order, mirroring there, and converting back to Nabla order
130+
static uint32_t getNablaMirrorIndex(uint32_t idx)
131+
{
132+
return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
133+
}
134+
135+
// log2 of the number of elements each invocation handles
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
136+
// log2 of the total FFT size, i.e. the number of index bits the helpers above operate on
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
137+
// Total number of elements in the FFT; both factors are widened to uint32_t before multiplying
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
138+
};
139+
132140
} //namespace fft
133141

134142
// ----------------------------------- End Utils -----------------------------------------------
135143

136-
template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
144+
template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
137145
struct FFT;
138146

139147
// For the FFT methods below, we assume:
@@ -153,13 +161,13 @@ struct FFT;
153161
// * void workgroupExecutionAndMemoryBarrier();
154162

155163
// 2 items per invocation forward specialization
156-
template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
164+
template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
157165
struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
158166
{
159167
template<typename SharedMemoryAdaptor>
160168
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
161169
{
162-
fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
170+
fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
163171

164172
// Get twiddle with k = threadID mod stride, halfN = stride
165173
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
@@ -199,7 +207,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
199207
}
200208

201209
// special last workgroup-shuffle
202-
fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
210+
fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
203211

204212
// Remember to update the accessor's state
205213
sharedmemAccessor = sharedmemAdaptor.accessor;
@@ -217,7 +225,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
217225

218226

219227
// 2 items per invocation inverse specialization
220-
template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
228+
template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
221229
struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
222230
{
223231
template<typename SharedMemoryAdaptor>
@@ -226,7 +234,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
226234
// Get twiddle with k = threadID mod stride, halfN = stride
227235
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
228236

229-
fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
237+
fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
230238
}
231239

232240

@@ -255,7 +263,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
255263
sharedmemAdaptor.accessor = sharedmemAccessor;
256264

257265
// special first workgroup-shuffle
258-
fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
266+
fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
259267

260268
// The bigger steps
261269
[unroll]
@@ -283,7 +291,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
283291
};
284292

285293
// Forward FFT
286-
template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
294+
template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
287295
struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
288296
{
289297
template<typename Accessor, typename SharedMemoryAccessor>
@@ -326,7 +334,7 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
326334
};
327335

328336
// Inverse FFT
329-
template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
337+
template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
330338
struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
331339
{
332340
template<typename Accessor, typename SharedMemoryAccessor>

0 commit comments

Comments
 (0)