@@ -14,34 +14,63 @@ namespace hlsl
14
14
namespace fft
15
15
{
16
16
17
- // template parameter N controls the number of dimensions of the input
18
- // template parameter M controls the number of dimensions to pad up to PoT
19
- // "axes" indicates which dimensions to pad up to PoT
20
- template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES (M <= N)
21
- inline vector <uint64_t, 3 > padDimensions (NBL_CONST_REF_ARG (vector <uint32_t, N>) dimensions, NBL_CONST_REF_ARG (vector <uint16_t, M>) axes, bool realFFT = false )
17
+
18
+ template <uint16_t N NBL_FUNC_REQUIRES (N > 0 && N <= 4 )
19
+ /**
20
+ * @brief Returns the size of the full FFT computed, in terms of number of complex elements.
21
+ *
22
+ * @tparam N Number of dimensions of the signal to perform FFT on.
23
+ *
24
+ * @param [in] dimensions Size of the signal.
25
+ * @param [in] realFFT Indicates whether the signal is real. False by default.
26
+ * @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals. Must be less than N. 0 by default.
27
+ */
28
+ inline vector <uint64_t, N> padDimensions (NBL_CONST_REF_ARG (vector <uint32_t, N>) dimensions, bool realFFT = false , uint16_t firstAxis = 0u)
22
29
{
23
30
vector <uint32_t, N> newDimensions = dimensions;
24
- uint16_t axisCount = 0 ;
25
- for (uint16_t i = 0u; i < M; i++)
31
+ for (uint16_t i = 0u; i < N; i++)
26
32
{
27
33
newDimensions[i] = hlsl::roundUpToPoT (newDimensions[i]);
28
- if (realFFT && !axisCount++)
29
- newDimensions[i] /= 2 ;
30
34
}
35
+ if (realFFT)
36
+ newDimensions[firstAxis] /= 2 ;
31
37
return newDimensions;
32
38
}
33
39
34
- // template parameter N controls the number of dimensions of the input
35
- // template parameter M controls the number of dimensions we run an FFT along AND store the result
36
- // "axes" indicates which dimensions we run an FFT along AND store the result
37
- template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES (M <= N)
38
- inline uint64_t getOutputBufferSize (NBL_CONST_REF_ARG (vector <uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG (vector <uint16_t, M>) axes, bool realFFT = false , bool halfFloats = false )
40
+ template <uint16_t N NBL_FUNC_REQUIRES (N > 0 && N <= 4 )
41
+ /**
42
+ * @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass.
43
+ *
44
+ * @tparam N Number of dimensions of the signal to perform FFT on.
45
+ *
46
+ * @param [in] numChannels Number of channels of the signal.
47
+ * @param [in] inputDimensions Size of the signal.
48
+ * @param [in] passIx Which pass the size is being computed for.
49
+ * @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw.
50
+ * @param [in] realFFT True if the signal is real. False by default.
51
+ * @param [in] halfFloats True if using half-precision floats. False by default.
52
+ */
53
+ inline uint64_t getOutputBufferSize (
54
+ uint32_t numChannels,
55
+ NBL_CONST_REF_ARG (vector <uint32_t, N>) inputDimensions,
56
+ uint16_t passIx,
57
+ NBL_CONST_REF_ARG (vector <uint16_t, N>) axisPassOrder = _static_cast<vector <uint16_t, N> >(uint16_t4 (0 , 1 , 2 , 3 )),
58
+ bool realFFT = false ,
59
+ bool halfFloats = false
60
+ )
39
61
{
40
- const vector <uint64_t, 3 > paddedDims = padDimensions<N, M>(inputDimensions, axes);
41
- const uint64_t numberOfComplexElements = paddedDims[0 ] * paddedDims[1 ] * paddedDims[2 ] * uint64_t (numChannels);
62
+ const vector <uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions, realFFT, axisPassOrder[0 ]);
63
+ vector <bool , N> axesDone = promote<vector <bool , N>, bool >(false );
64
+ for (uint16_t i = 0 ; i <= passIx; i++)
65
+ axesDone[axisPassOrder[i]] = true ;
66
+ const vector <uint32_t, N> passOutputDimension = lerp (inputDimensions, paddedDimensions, axesDone);
67
+ uint64_t numberOfComplexElements = uint64_t (numChannels);
68
+ for (uint16_t i = 0 ; i < N; i++)
69
+ numberOfComplexElements *= uint64_t (passOutputDimension[i]);
42
70
return numberOfComplexElements * (halfFloats ? sizeof (complex_t<float16_t>) : sizeof (complex_t<float32_t>));
43
71
}
44
72
73
+
45
74
// Computes the kth element in the group of N roots of unity
46
75
// Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT)
47
76
template<bool inverse, typename Scalar>
@@ -95,11 +124,33 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
95
124
lo = x;
96
125
}
97
126
98
- // Bit-reverses T as a binary string of length given by Bits
99
- template<typename T, uint16_t Bits NBL_FUNC_REQUIRES (is_integral_v<T> && Bits <= sizeof (T) * 8 )
127
+ template<typename T, uint16_t Bits NBL_FUNC_REQUIRES (is_unsigned_v<T>&& Bits <= sizeof (T) * 8 )
128
+ /**
129
+ * @brief Takes the binary representation of `value` as a string of `Bits` bits and returns a value of the same type resulting from reversing the string
130
+ *
131
+ * @tparam T Type of the value to operate on.
132
+ * @tparam Bits The length of the string of bits used to represent `value`.
133
+ *
134
+ * @param [in] value The value to bitreverse.
135
+ */
100
136
T bitReverseAs (T value)
101
137
{
102
- return hlsl::bitReverse<uint32_t>(value) >> (sizeof (T) * 8 - Bits);
138
+ return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof (T) * 8 - Bits));
139
+ }
140
+
141
+ template<typename T NBL_FUNC_REQUIRES (is_unsigned_v<T>)
142
+ /**
143
+ * @brief Takes the binary representation of `value` and returns a value of the same type resulting from reversing the string of bits as if it was `bits` long.
144
+ * Keep in mind `bits` cannot exceed `8 * sizeof(T)`.
145
+ *
146
+ * @tparam T type of the value to operate on.
147
+ *
148
+ * @param [in] value The value to bitreverse.
149
+ * @param [in] bits The length of the string of bits used to represent `value`.
150
+ */
151
+ T bitReverseAs (T value, uint16_t bits)
152
+ {
153
+ return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof (T) * 8 - bits));
103
154
}
104
155
105
156
}
0 commit comments