@@ -23,85 +23,125 @@ namespace fft
 {

 // ---------------------------------- Utils -----------------------------------------------
-template<typename SharedMemoryAdaptor, typename Scalar>
-struct exchangeValues;
-
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float16_t>
+// No need to expose these
+namespace impl
 {
-    static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+template<typename SharedMemoryAdaptor, typename Scalar>
+struct exchangeValues
 {
-        const bool topHalf = bool(threadID & stride);
-        // Pack two halves into a single uint32_t
-        uint32_t toExchange = bit_cast<uint32_t, float16_t2>(topHalf ? float16_t2(lo.real(), lo.imag()) : float16_t2(hi.real(), hi.imag()));
-        shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
-        float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
-        if (topHalf)
+    static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
+        const bool topHalf = bool(threadID & stride);
+        // Pack into float vector because ternary operator does not support structs
+        vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
+        shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
     }
-        else
-        {
-            hi.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-    }
-};
+};
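The generic `exchangeValues` above replaces three per-type specializations with one butterfly trade: threads paired by `threadID ^ stride` swap halves, and `threadID & stride` decides whether a thread gives up `lo` or `hi`. Here is a host-side C++ sketch of that trade (the simulation names and the array standing in for workgroup shared memory are ours; we assume `shuffleXor` publishes a value and reads back the partner's at `threadID ^ stride`):

```cpp
// Host-side simulation of impl::exchangeValues' shared-memory trade.
#include <array>
#include <cassert>
#include <complex>
#include <cstdint>

using cpx = std::complex<float>;
constexpr uint32_t WorkgroupSize = 8;

int main()
{
    std::array<cpx, WorkgroupSize> lo, hi, shared;
    for (uint32_t t = 0; t < WorkgroupSize; t++)
    {
        lo[t] = cpx(float(t), 0.f); // distinct payload per thread
        hi[t] = cpx(0.f, float(t));
    }

    const uint32_t stride = 2;
    // Each thread publishes the half that must travel: top threads send lo, bottom threads send hi
    for (uint32_t t = 0; t < WorkgroupSize; t++)
        shared[t] = (t & stride) ? lo[t] : hi[t];
    // ...then reads its partner's slot (this is the XOR shuffle) into the half it gave away
    for (uint32_t t = 0; t < WorkgroupSize; t++)
    {
        const cpx exchanged = shared[t ^ stride];
        if (t & stride) lo[t] = exchanged;
        else            hi[t] = exchanged;
    }

    // Thread 0 (bottom half) kept its lo and received thread 2's lo into hi
    assert(lo[0] == cpx(0.f, 0.f) && hi[0] == cpx(2.f, 0.f));
    // Thread 2 (top half) kept its hi and received thread 0's hi into lo
    assert(hi[2] == cpx(0.f, 2.f) && lo[2] == cpx(0.f, 0.f));
}
```

After the trade, each pair of threads holds the operand layout the next butterfly stage expects, without any per-type bit_cast packing.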

-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float32_t>
-{
-    static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+template<uint16_t N, uint16_t H>
+enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
 {
-        const bool topHalf = bool(threadID & stride);
-        // pack into `float32_t2` because ternary operator doesn't support structs
-        float32_t2 exchanged = topHalf ? float32_t2(lo.real(), lo.imag()) : float32_t2(hi.real(), hi.imag());
-        shuffleXor<SharedMemoryAdaptor, float32_t2>::__call(exchanged, stride, sharedmemAdaptor);
-        if (topHalf)
-        {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-        else
-        {
-            hi.real(exchanged.x);
-            hi.imag(exchanged.y);
-        }
+    // Highest H bits are numbered N-1 through N - H
+    // N - H is then the middle bit
+    // Lowest bits numbered from 0 through N - H - 1
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask);
+
+    uint32_t low = i & lowMask;
+    uint32_t mid = i & midMask;
+    uint32_t high = i & highMask;
+
+    high >>= 1;
+    mid <<= H - 1;
+
+    return mid | high | low;
 }
-};

-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float64_t>
-{
-    static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+template<uint16_t N, uint16_t H>
+enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
 {
-        const bool topHalf = bool(threadID & stride);
-        // pack into `float64_t2` because ternary operator doesn't support structs
-        float64_t2 exchanged = topHalf ? float64_t2(lo.real(), lo.imag()) : float64_t2(hi.real(), hi.imag());
-        shuffleXor<SharedMemoryAdaptor, float64_t2>::__call(exchanged, stride, sharedmemAdaptor);
-        if (topHalf)
-        {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-        else
-        {
-            hi.real(exchanged.x);
-            hi.imag(exchanged.y);
-        }
+    // Highest H bits are numbered N-1 through N - H
+    // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
+    // Lowest bits numbered from 0 through N - H - 1
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask);
+
+    uint32_t low = i & lowMask;
+    uint32_t mid = i & midMask;
+    uint32_t high = i & highMask;
+
+    mid <<= 1;
+    high >>= H - 1;
+
+    return mid | high | low;
 }
-};
+} //namespace impl
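These two helpers rotate only the top `H` bits of an `N`-bit index and leave the low `N - H` bits untouched, and `circularBitShiftLeftHigher` undoes `circularBitShiftRightHigher`. A standalone C++ mirror of the same mask logic (our names), with one worked bit pattern and an exhaustive inverse check:

```cpp
// Rotate only the top H bits of an N-bit value; the low N - H bits stay put.
#include <cassert>
#include <cstdint>

template<uint16_t N, uint16_t H>
uint32_t circularBitShiftRightHigher(uint32_t i)
{
    const uint32_t lowMask = (1u << (N - H)) - 1;
    const uint32_t midMask = 1u << (N - H);
    const uint32_t highMask = ~(lowMask | midMask);
    uint32_t low = i & lowMask, mid = i & midMask, high = i & highMask;
    high >>= 1;
    mid <<= H - 1; // the bit shifted out of the bottom of the high field wraps to its top
    return mid | high | low;
}

template<uint16_t N, uint16_t H>
uint32_t circularBitShiftLeftHigher(uint32_t i)
{
    const uint32_t lowMask = (1u << (N - H)) - 1;
    const uint32_t highMask = 1u << (N - 1);
    const uint32_t midMask = ~(lowMask | highMask);
    uint32_t low = i & lowMask, mid = i & midMask, high = i & highMask;
    mid <<= 1;
    high >>= H - 1; // the top bit wraps to the bottom of the high field
    return mid | high | low;
}

int main()
{
    // N = 5, H = 3: high field of 0b10110 is 101; rotated right once it becomes 110 -> 0b11010
    assert((circularBitShiftRightHigher<5, 3>(0b10110u)) == 0b11010u);
    // The left rotation is its exact inverse over the whole domain
    for (uint32_t i = 0; i < (1u << 5); i++)
        assert((circularBitShiftLeftHigher<5, 3>(circularBitShiftRightHigher<5, 3>(i))) == i);
}
```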

 // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
-template<typename scalar_t, uint32_t WorkgroupSize>
+template<typename scalar_t, uint16_t WorkgroupSize>
 NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
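For scale: a `complex_t<float32_t>` is two 32-bit floats, i.e. 2 DWORDs per element, so a 256-thread workgroup needs 512 DWORDs of shared memory. A compile-time sketch of the same formula, modeling `complex_t` with `std::complex` (an assumption about layout, though both are just two scalars):

```cpp
// Back-of-envelope check of the SharedMemoryDWORDs formula on the host.
#include <complex>
#include <cstdint>

template<typename scalar_t, uint16_t WorkgroupSize>
constexpr uint32_t SharedMemoryDWORDs =
    (sizeof(std::complex<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;

static_assert(SharedMemoryDWORDs<float, 256> == 512, "2 DWORDs per complex float");
static_assert(SharedMemoryDWORDs<double, 256> == 1024, "4 DWORDs per complex double");

int main() {}
```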

+// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
+template<typename Scalar>
+void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
+{
+    complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
+    hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
+    lo = x;
+}
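`unpack` is the classic two-real-FFTs-for-one-complex-FFT trick: if `Z = DFT(x + iy)` for real signals `x`, `y`, then `X[k] = (Z[k] + conj(Z[N-k]))/2` and `Y[k] = -i * (Z[k] - conj(Z[N-k]))/2`, so we read `rotateRight` here as multiplication by `-i` (our assumption; it is what the standard identity requires). A host-side check against a naive DFT:

```cpp
// Pack two real signals as x + iy, transform once, recover X[k] and Y[k]
// from Z[k] and its mirror Z[N-k], and compare against separate DFTs.
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cpx = std::complex<double>;

std::vector<cpx> naiveDFT(const std::vector<cpx>& in)
{
    const double PI = std::acos(-1.0);
    const size_t N = in.size();
    std::vector<cpx> out(N);
    for (size_t k = 0; k < N; k++)
        for (size_t n = 0; n < N; n++)
            out[k] += in[n] * std::polar(1.0, -2.0 * PI * double(k * n) / double(N));
    return out;
}

int main()
{
    const size_t N = 8;
    std::vector<double> x = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<double> y = {8, 7, 6, 5, 4, 3, 2, 1};

    std::vector<cpx> packed(N);
    for (size_t n = 0; n < N; n++)
        packed[n] = cpx(x[n], y[n]); // one complex FFT carries both signals

    const auto Z = naiveDFT(packed);
    std::vector<cpx> xr(x.begin(), x.end()), yr(y.begin(), y.end());
    const auto X = naiveDFT(xr), Y = naiveDFT(yr); // references, computed separately

    for (size_t k = 0; k < N; k++)
    {
        const cpx lo = Z[k], hi = Z[(N - k) & (N - 1)];          // mirror about Nyquist
        const cpx Xk = (lo + std::conj(hi)) * 0.5;               // unpack's `x`
        const cpx Yk = cpx(0, -1) * (lo - std::conj(hi)) * 0.5;  // rotateRight == multiply by -i
        assert(std::abs(Xk - X[k]) < 1e-9 && std::abs(Yk - Y[k]) < 1e-9);
    }
}
```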
+
+template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
+struct FFTIndexingUtils
+{
+    // This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
+    // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
+    static uint32_t getDFTIndex(uint32_t outputIdx)
+    {
+        return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFTSizeLog2));
+    }
+
+    // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
+    // It is essentially the inverse of `getDFTIndex`
+    static uint32_t getNablaIndex(uint32_t freqIdx)
+    {
+        return glsl::bitfieldReverse<uint32_t>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx)) >> (32 - FFTSizeLog2);
+    }
+
+    // Mirrors an index about the Nyquist frequency in the DFT order
+    static uint32_t getDFTMirrorIndex(uint32_t idx)
+    {
+        return (FFTSize - idx) & (FFTSize - 1);
+    }
+
+    // Given an index `idx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
+    static uint32_t getNablaMirrorIndex(uint32_t idx)
+    {
+        return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
+};
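A host-side mirror of `FFTIndexingUtils` for `ElementsPerInvocation = 2`, `WorkgroupSize = 8` (so `FFTSize = 16`), assuming `glsl::bitfieldReverse` reverses all 32 bits; it checks that `getNablaIndex` really inverts `getDFTIndex` and that the mirror helpers compose as advertised. The rotation helpers are compact re-statements of the `impl` ones above, not Nabla's code:

```cpp
#include <cassert>
#include <cstdint>

constexpr uint16_t ElementsPerInvocationLog2 = 1;               // log2(2)
constexpr uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + 3; // + log2(8)
constexpr uint32_t FFTSize = 1u << FFTSizeLog2;

// All 32 bits reversed, as we assume glsl::bitfieldReverse does
uint32_t bitfieldReverse32(uint32_t v)
{
    uint32_t r = 0;
    for (int b = 0; b < 32; b++)
        r = (r << 1) | ((v >> b) & 1u);
    return r;
}

// Rotate the top H of N bits right/left by one, low bits untouched
template<uint16_t N, uint16_t H>
uint32_t rotHighRight(uint32_t i)
{
    const uint32_t low = i & ((1u << (N - H)) - 1);
    const uint32_t high = i >> (N - H);
    const uint32_t rot = ((high >> 1) | (high << (H - 1))) & ((1u << H) - 1);
    return (rot << (N - H)) | low;
}

template<uint16_t N, uint16_t H>
uint32_t rotHighLeft(uint32_t i)
{
    const uint32_t low = i & ((1u << (N - H)) - 1);
    const uint32_t high = i >> (N - H);
    const uint32_t rot = ((high << 1) | (high >> (H - 1))) & ((1u << H) - 1);
    return (rot << (N - H)) | low;
}

uint32_t getDFTIndex(uint32_t outputIdx)
{
    return rotHighRight<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(
        bitfieldReverse32(outputIdx) >> (32 - FFTSizeLog2));
}

uint32_t getNablaIndex(uint32_t freqIdx)
{
    return bitfieldReverse32(
        rotHighLeft<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx)) >> (32 - FFTSizeLog2);
}

int main()
{
    // Round trip: Nabla order -> DFT order -> back
    for (uint32_t i = 0; i < FFTSize; i++)
        assert(getNablaIndex(getDFTIndex(i)) == i);

    // Negative-frequency mirror pulled back into Nabla order lands on the mirrored DFT bin
    for (uint32_t i = 0; i < FFTSize; i++)
    {
        const uint32_t mirrorDFT = (FFTSize - getDFTIndex(i)) & (FFTSize - 1);
        assert(getDFTIndex(getNablaIndex(mirrorDFT)) == mirrorDFT);
    }
}
```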
+

 } //namespace fft

 // ----------------------------------- End Utils -----------------------------------------------

-template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
+template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
 struct FFT;

 // For the FFT methods below, we assume:
@@ -121,13 +161,13 @@ struct FFT;
 // * void workgroupExecutionAndMemoryBarrier();

 // 2 items per invocation forward specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2, false, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename SharedMemoryAdaptor>
     static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+        fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);

         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
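For reference, the radix-2 butterflies these `DIF`/`DIT` calls are understood to perform are the textbook ones: DIF takes sum and difference first and twiddles the difference, DIT twiddles first and then takes sum and difference, and with conjugated twiddles one undoes the other up to a factor of 2 per pass. A sketch of that math (ours, not Nabla's implementation):

```cpp
// Textbook decimation-in-frequency / decimation-in-time radix-2 butterflies.
#include <cassert>
#include <complex>

using cpx = std::complex<double>;

// DIF: butterfly first, then twiddle the difference
void difRadix2(cpx w, cpx& lo, cpx& hi)
{
    const cpx sum = lo + hi;
    hi = (lo - hi) * w;
    lo = sum;
}

// DIT: twiddle first, then butterfly
void ditRadix2(cpx w, cpx& lo, cpx& hi)
{
    const cpx t = hi * w;
    hi = lo - t;
    lo = lo + t;
}

int main()
{
    const cpx lo(1, 2), hi(3, -1);
    const cpx w = std::polar(1.0, -0.7); // any unit-magnitude twiddle

    cpx a = lo, b = hi;
    difRadix2(w, a, b);
    ditRadix2(std::conj(w), a, b); // DIT with the conjugate twiddle undoes DIF...
    a *= 0.5; b *= 0.5;            // ...up to the factor of 2 each radix-2 pass introduces
    assert(std::abs(a - lo) < 1e-12 && std::abs(b - hi) < 1e-12);
}
```

This is also why the inverse specialization below runs DIT with conjugated (`twiddle<true, ...>`) twiddles and mirrors the shuffle order of the forward pass.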
@@ -167,7 +207,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
         }

         // special last workgroup-shuffle
-        fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);

         // Remember to update the accessor's state
         sharedmemAccessor = sharedmemAdaptor.accessor;
@@ -185,7 +225,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>


 // 2 items per invocation inverse specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2, true, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename SharedMemoryAdaptor>
@@ -194,7 +234,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);

-        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+        fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
     }

@@ -223,7 +263,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
         sharedmemAdaptor.accessor = sharedmemAccessor;

         // special first workgroup-shuffle
-        fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);

         // The bigger steps
         [unroll]
@@ -251,7 +291,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 };

 // Forward FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename Accessor, typename SharedMemoryAccessor>
@@ -294,7 +334,7 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 };

 // Inverse FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename Accessor, typename SharedMemoryAccessor>