@@ -23,102 +23,76 @@ namespace fft
 {
 
 // ---------------------------------- Utils -----------------------------------------------
-template<typename SharedMemoryAdaptor, typename Scalar>
-struct exchangeValues
+
+// No need to expose these
+namespace impl
 {
-    static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    template<typename SharedMemoryAdaptor, typename Scalar>
+    struct exchangeValues
     {
-        const bool topHalf = bool(threadID & stride);
-        // Pack into float vector because ternary operator does not support structs
-        vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
-        shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
-        if (topHalf)
-        {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-        else
+        static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
         {
-            hi.real(exchanged.x);
-            hi.imag(exchanged.y);
+            const bool topHalf = bool(threadID & stride);
+            // Pack into float vector because ternary operator does not support structs
+            vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
+            shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
+            if (topHalf)
+            {
+                lo.real(exchanged.x);
+                lo.imag(exchanged.y);
+            }
+            else
+            {
+                hi.real(exchanged.x);
+                hi.imag(exchanged.y);
+            }
         }
-    }
-};
-
-// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
-template<typename scalar_t, uint32_t WorkgroupSize>
-NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
-
+    };
 
-template<uint32_t N, uint32_t H>
-enable_if_t<H <= N, uint32_t> bitShiftRightHigher(uint32_t i)
-{
-    // Highest H bits are numbered N-1 through N - H
-    // N - H is then the middle bit
-    // Lowest bits numbered from 0 through N - H - 1
-    uint32_t low = i & ((1 << (N - H)) - 1);
-    uint32_t mid = i & (1 << (N - H));
-    uint32_t high = i & ~((1 << (N - H + 1)) - 1);
-
-    high >>= 1;
-    mid <<= H - 1;
-
-    return mid | high | low;
-}
+    template<uint16_t N, uint16_t H>
+    enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
+    {
+        // Highest H bits are numbered N-1 through N - H
+        // N - H is then the middle bit
+        // Lowest bits numbered from 0 through N - H - 1
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~((1 << (N - H + 1)) - 1);
 
-template<uint32_t N, uint32_t H>
-enable_if_t<H <= N, uint32_t> bitShiftLeftHigher(uint32_t i)
-{
-    // Highest H bits are numbered N-1 through N - H
-    // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
-    // Lowest bits numbered from 0 through N - H - 1
-    uint32_t low = i & ((1 << (N - H)) - 1);
-    uint32_t mid = i & (~((1 << (N - H)) - 1) | ~(1 << (N - 1)));
-    uint32_t high = i & (1 << (N - 1));
+        uint32_t low = i & lowMask;
+        uint32_t mid = i & midMask;
+        uint32_t high = i & highMask;
 
-    mid <<= 1;
-    high >>= H - 1;
+        high >>= 1;
+        mid <<= H - 1;
 
-    return mid | high | low;
-}
+        return mid | high | low;
+    }
 
-// This function maps the index `idx` in the output array of a Forward FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = output[idx]`
-// This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
-template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
-uint32_t getFrequencyIndex(uint32_t outputIdx)
-{
-    NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t(mpl::log2<ElementsPerInvocation>::value);
-    NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t(mpl::log2<WorkgroupSize>::value);
+    template<uint16_t N, uint16_t H>
+    enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
+    {
+        // Highest H bits are numbered N-1 through N - H
+        // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
+        // Lowest bits numbered from 0 through N - H - 1
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~((1 << (N - H)) - 1) | ~(1 << (N - 1));
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
 
-    return bitShiftRightHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFT_SIZE_LOG_2));
-}
+        uint32_t low = i & lowMask;
+        uint32_t mid = i & midMask;
+        uint32_t high = i & highMask;
 
-// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Forward FFT such that `DFT[freqIdx] = output[idx]`
-// It is essentially the inverse of `getFrequencyIndex`
-template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
-uint32_t getOutputIndex(uint32_t freqIdx)
-{
-    NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t(mpl::log2<ElementsPerInvocation>::value);
-    NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t(mpl::log2<WorkgroupSize>::value);
+        mid <<= 1;
+        high >>= H - 1;
 
-    return glsl::bitfieldReverse<uint32_t>(bitShiftLeftHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1>(freqIdx)) >> (32 - FFT_SIZE_LOG_2);
-}
-
-// Mirrors an index about the Nyquist frequency
-template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
-uint32_t mirror(uint32_t idx)
-{
-    NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE = WorkgroupSize * uint32_t(ElementsPerInvocation);
-    return (FFT_SIZE - idx) & (FFT_SIZE - 1);
-}
+        return mid | high | low;
+    }
+} //namespace impl
 
-// When packing real FFTs a common operation is to get `DFT[T]` and `DFT[-T]` to unpack the result of a packed real FFT.
-// Given an index `idx` into the Nabla-ordered DFT such that `output[idx] = DFT[T]`, this function is such that `output[getNegativeIndex(idx)] = DFT[-T]`
-template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
-uint32_t getNegativeIndex(uint32_t idx)
-{
-    return getOutputIndex<ElementsPerInvocation, WorkgroupSize>(mirror<ElementsPerInvocation, WorkgroupSize>(getFrequencyIndex<ElementsPerInvocation, WorkgroupSize>(idx)));
-}
+// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
+template<typename scalar_t, uint16_t WorkgroupSize>
+NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
 
 // Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
 template<typename Scalar>
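
The two `circularBitShift*Higher` helpers rotate the top `H` bits of an `N`-bit index by one position while leaving the low `N - H` bits untouched; this is the permutation the indexing utilities below build on. A minimal standalone C++ sketch of the idea (function names and test values are mine, not part of the commit):

```cpp
#include <cassert>
#include <cstdint>

// Rotate the top H bits of an N-bit index one step to the right,
// leaving the low N - H bits alone. Same masks as the helper above,
// with N and H as plain runtime parameters instead of template arguments.
uint32_t rotateHighBitsRight(uint32_t i, uint32_t N, uint32_t H)
{
    const uint32_t low  = i & ((1u << (N - H)) - 1);      // bits N-H-1 .. 0, kept as-is
    const uint32_t mid  = i & (1u << (N - H));            // lowest of the H high bits
    const uint32_t high = i & ~((1u << (N - H + 1)) - 1); // remaining H-1 high bits

    // Upper H-1 bits slide down one slot; the lowest high bit wraps to the top
    return (mid << (H - 1)) | (high >> 1) | low;
}

// The inverse rotation: the topmost bit wraps down to position N - H
uint32_t rotateHighBitsLeft(uint32_t i, uint32_t N, uint32_t H)
{
    const uint32_t low  = i & ((1u << (N - H)) - 1);
    const uint32_t mid  = i & (((1u << (N - 1)) - 1) & ~((1u << (N - H)) - 1)); // bits N-2 .. N-H
    const uint32_t high = i & (1u << (N - 1));                                  // the topmost bit

    return (mid << 1) | (high >> (H - 1)) | low;
}

int main()
{
    // N = 5, H = 3: [b4 b3 b2 | b1 b0] -> [b2 b4 b3 | b1 b0]
    assert(rotateHighBitsRight(0b10110u, 5, 3) == 0b11010u);
    // ...and rotating back left restores the original index
    assert(rotateHighBitsLeft(0b11010u, 5, 3) == 0b10110u);
    return 0;
}
```

Making `N` and `H` template parameters and hoisting the masks into `NBL_CONSTEXPR_STATIC_INLINE` constants, as the commit does, lets the compiler fold everything down to a handful of shifts and ORs.
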
@@ -129,11 +103,45 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
     lo = x;
 }
 
+template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
+struct FFTIndexingUtils
+{
+    // This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
+    // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
+    static uint32_t getDFTIndex(uint32_t outputIdx)
+    {
+        return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFTSizeLog2));
+    }
+
+    // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
+    // It is essentially the inverse of `getDFTIndex`
+    static uint32_t getNablaIndex(uint32_t freqIdx)
+    {
+        return glsl::bitfieldReverse<uint32_t>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx)) >> (32 - FFTSizeLog2);
+    }
+
+    // Mirrors an index about the Nyquist frequency in the DFT order
+    static uint32_t getDFTMirrorIndex(uint32_t idx)
+    {
+        return (FFTSize - idx) & (FFTSize - 1);
+    }
+
+    // Given an index `idx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
+    static uint32_t getNablaMirrorIndex(uint32_t idx)
+    {
+        return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
+};
+
 } //namespace fft
 
 // ----------------------------------- End Utils -----------------------------------------------
 
-template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
+template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
 struct FFT;
 
 // For the FFT methods below, we assume:
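
What `FFTIndexingUtils` promises is a pair of mutually inverse permutations: `getNablaIndex(getDFTIndex(i)) == i` for every index, and `getNablaMirrorIndex` composes them around the DFT mirror so that unpacking a packed real FFT can fetch `DFT[T]` and `DFT[-T]` together. A quick host-side check of those properties, reusing `rotateHighBitsRight`/`rotateHighBitsLeft` from the sketch after the first hunk (the sizes below are my example choice, and `bitReverse32` is a software stand-in for the GPU intrinsic):

```cpp
#include <cassert>
#include <cstdint>

// Software stand-in for glsl::bitfieldReverse
uint32_t bitReverse32(uint32_t v)
{
    uint32_t r = 0;
    for (int b = 0; b < 32; ++b)
        r |= ((v >> b) & 1u) << (31 - b);
    return r;
}

// rotateHighBitsRight / rotateHighBitsLeft: paste in from the sketch above

int main()
{
    // Example sizes: ElementsPerInvocation = 4, WorkgroupSize = 8 -> FFTSize = 32
    const uint32_t elemLog2 = 2, wgLog2 = 3;
    const uint32_t fftLog2  = elemLog2 + wgLog2;
    const uint32_t fftSize  = 1u << fftLog2;
    const uint32_t H        = fftLog2 - elemLog2 + 1; // same H as in the struct

    for (uint32_t idx = 0; idx < fftSize; ++idx)
    {
        // getDFTIndex followed by getNablaIndex must be the identity
        const uint32_t dft  = rotateHighBitsRight(bitReverse32(idx) >> (32 - fftLog2), fftLog2, H);
        const uint32_t back = bitReverse32(rotateHighBitsLeft(dft, fftLog2, H)) >> (32 - fftLog2);
        assert(back == idx);

        // getDFTMirrorIndex pairs each frequency with its negative: k + (-k) = 0 mod FFTSize
        const uint32_t mirror = (fftSize - dft) & (fftSize - 1);
        assert(((dft + mirror) & (fftSize - 1)) == 0);
    }
    return 0;
}
```
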
@@ -153,13 +161,13 @@ struct FFT;
 // * void workgroupExecutionAndMemoryBarrier();
 
 // 2 items per invocation forward specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2, false, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename SharedMemoryAdaptor>
     static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+        fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
 
         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
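
Per iteration the forward transform first trades `lo`/`hi` across the current stride, then runs a radix-2 decimation-in-frequency butterfly against a twiddle picked by `k = threadID mod stride`. A `std::complex` sketch of what that butterfly computes, per my reading of `hlsl::fft::DIF::radix2` and `hlsl::fft::twiddle` rather than the library source itself:

```cpp
#include <cmath>
#include <complex>
#include <cstdint>

using cf = std::complex<float>;

// Twiddle factor e^(-i*pi*k/halfN) for the forward FFT; the inverse
// transform uses the complex conjugate.
cf twiddle(bool inverse, uint32_t k, uint32_t halfN)
{
    const float angle = 3.14159265358979f * float(k) / float(halfN);
    return inverse ? cf(std::cos(angle),  std::sin(angle))
                   : cf(std::cos(angle), -std::sin(angle));
}

// Radix-2 decimation-in-frequency butterfly: the sum stays in `lo`,
// the twiddled difference goes to `hi`.
void radix2DIF(cf w, cf& lo, cf& hi)
{
    const cf diff = lo - hi;
    lo += hi;
    hi = diff * w;
}
```
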
@@ -199,7 +207,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
         }
 
         // special last workgroup-shuffle
-        fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
 
         // Remember to update the accessor's state
         sharedmemAccessor = sharedmemAdaptor.accessor;
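
The copy-back on the last line is easy to miss but load-bearing: the adaptor wraps the accessor by value, so any state the accessor mutates while the adaptor drives it exists only inside the copy. A schematic of the pattern with invented names (the real types are Nabla's shared-memory accessor and its memory adaptor, not these):

```cpp
#include <cstdint>

// Invented stand-ins purely to illustrate the ownership pattern
template<typename Accessor>
struct Adaptor
{
    Accessor accessor; // held by value, not by reference

    void set(uint32_t idx, uint32_t value) { accessor.set(idx, value); }
};

template<typename Accessor>
void shuffleStep(Accessor& sharedmemAccessor)
{
    Adaptor<Accessor> adaptor;
    adaptor.accessor = sharedmemAccessor; // wrap a copy
    adaptor.set(0u, 42u);                 // all work goes through the adaptor
    // Skip this write-back and whatever state the accessor tracked is lost:
    sharedmemAccessor = adaptor.accessor;
}
```
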
@@ -217,7 +225,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
 
 
 // 2 items per invocation inverse specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2, true, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename SharedMemoryAdaptor>
@@ -226,7 +234,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
 
-        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+        fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
     }
 
 
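
The inverse specialization mirrors the forward one: decimation-in-time applies the twiddle on the way into the butterfly, and only then trades values across the stride. Continuing the `std::complex` sketch from above (same `cf` alias and `twiddle`, same caveat that this is my reading, not the library source):

```cpp
// Radix-2 decimation-in-time butterfly: twiddle first, then sum and
// difference. With w = twiddle(true, k, halfN), the conjugate of the
// forward twiddle, this undoes the DIF step above up to a factor of 2,
// which the full inverse FFT divides out at the end.
void radix2DIT(cf w, cf& lo, cf& hi)
{
    hi *= w;
    const cf sum = lo + hi;
    hi = lo - hi;
    lo = sum;
}
```
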
@@ -255,7 +263,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
         sharedmemAdaptor.accessor = sharedmemAccessor;
 
         // special first workgroup-shuffle
-        fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
 
         // The bigger steps
         [unroll]
@@ -283,7 +291,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 };
 
 // Forward FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename Accessor, typename SharedMemoryAccessor>
@@ -326,7 +334,7 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 };
 
 // Inverse FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename Accessor, typename SharedMemoryAccessor>