
Commit 0f8bcac

Merge pull request #818 from Devsh-Graphics-Programming/more_fft_utils
FFT Fixes
2 parents: 6fa23b1 + 5518e01

4 files changed (+69 / -26 lines)


include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl

Lines changed: 2 additions & 2 deletions
```diff
@@ -53,8 +53,8 @@ NBL_CONCEPT_BEGIN(3)
 #define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
 #define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
 NBL_CONCEPT_END(
-	((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
-	((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
+	((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<complex_t<Scalar> >(index, val)), is_same_v, void))
+	((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<complex_t<Scalar> >(index, val)), is_same_v, void))
 );
 #undef val
 #undef index
```
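For orientation, a minimal sketch of an accessor that would satisfy the tightened concept follows. It is not code from this commit: the `data` binding is hypothetical and the barrier body is elided.

```hlsl
// Illustrative accessor matching the updated concept: templated set/get returning void,
// specializable with AccessType = complex_t<Scalar>. Not part of this commit.
#include "nbl/builtin/hlsl/workgroup/fft.hlsl"

using namespace nbl::hlsl;

// Hypothetical user-provided binding holding the FFT data in global memory
RWStructuredBuffer<complex_t<float32_t> > data;

struct GlobalBufferAccessor
{
	template<typename AccessType>
	void set(uint32_t idx, AccessType value)
	{
		data[idx] = value;
	}

	template<typename AccessType>
	void get(uint32_t idx, NBL_REF_ARG(AccessType) value)
	{
		value = data[idx];
	}

	// Only needed when ElementsPerInvocationLog2 > 1; a real global-memory accessor must
	// issue an AcquireRelease barrier with semantics matching the memory it touches (elided here)
	void memoryBarrier() {}
};
```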

include/nbl/builtin/hlsl/fft/README.md

Lines changed: 10 additions & 4 deletions
```diff
@@ -12,9 +12,10 @@ To run an FFT, you need to call the FFT struct's static `__call` method. You do
 IMPORTANT: You MUST launch kernel with a workgroup size of `ConstevalParameters::WorkgroupSize`
 
 * `Accessor` is an accessor to the array. It MUST provide the methods
-`void get(uint32_t index, inout complex_t<Scalar> value)`,
-`void set(uint32_t index, in complex_t<Scalar> value)`,
-which are hopefully self-explanatory. Furthermore, if doing an FFT with `ElementsPerInvocationLog2 > 1`, it MUST also provide a `void memoryBarrier()` method. If not accessing any type of memory during the FFT, it can be a method that does nothing. Otherwise, it must do a barrier with `AcquireRelease` semantics, with proper semantics for the type of memory it accesses. This example uses an Accessor going straight to global memory, so it requires a memory barrier. For an example of an accessor that doesn't, see the `28_FFTBloom` example, where we use preloaded accessors.
+`template <typename AccessType> void set(uint32_t idx, AccessType value)` and
+`template <typename AccessType> void get(uint32_t idx, NBL_REF_ARG(AccessType) value)`
+which are hopefully self-explanatory. These methods need to be able to be specialized with `AccessType` being `complex_t<Scalar>` for the FFT to work properly.
+Furthermore, if doing an FFT with `ElementsPerInvocationLog2 > 1`, it MUST also provide a `void memoryBarrier()` method. If not accessing any type of memory during the FFT, it can be a method that does nothing. Otherwise, it must do a barrier with `AcquireRelease` semantics, with proper semantics for the type of memory it accesses. This example uses an Accessor going straight to global memory, so it requires a memory barrier. For an example of an accessor that doesn't, see the `28_FFTBloom` example, where we use preloaded accessors.
 
 * `SharedMemoryAccessor` is an accessor to a shared memory array of `uint32_t` that MUST be able to fit `WorkgroupSize` many complex elements (one per thread). When instantiating a `workgroup::fft::ConstevalParameters` struct, you can access its static member field `SharedMemoryDWORDs` that yields the amount of `uint32_t`s the shared memory array must be able to hold. It MUST provide the methods
 `template <typename IndexType, typename AccessType> void set(IndexType idx, AccessType value)`,
@@ -27,6 +28,8 @@ Furthermore, you must define the method `uint32_t3 nbl::hlsl::glsl::gl_WorkGroup
 
 ## Utils
 
+### Figuring out the storage required for an FFT
+
 ### Figuring out compile-time parameters
 We provide a
 `workgroup::fft::optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength)`
@@ -39,7 +42,9 @@ By default, we prefer to use only 2 elements per invocation when possible, and o
 ### Indexing
 We made some decisions in the design of the FFT algorithm pertaining to load/store order. In particular we wanted to keep stores linear to minimize cache misses when writing the output of an FFT. As such, the output of the FFT is not in its normal order, nor in bitreversed order (which is the standard for Cooley-Tukey implementations). Instead, it's in what we will refer to Nabla order going forward. The Nabla order allows for coalesced writes of the output.
 
-The result of an FFT (either forward or inverse, assuming the input is in its natural order) will be referred to as an $\text{NFFT}$ (N for Nabla). This $\text{NFFT}$ contains the same elements as the $\text{DFT}$ (which is the properly-ordered result of an FFT) of the same signal, just in Nabla order. We provide a struct
+This whole discussion applies to our implementation of the forward FFT only. We have not yet implemented the same functions for the inverse FFT since we didn't have a need for it.
+
+The result of a forward FFT will be referred to as an $\text{NFFT}$ (N for Nabla). This $\text{NFFT}$ contains the same elements as the $\text{DFT}$ (which is the properly-ordered result of an FFT) of the same signal, just in Nabla order. We provide a struct
 `FFTIndexingUtils<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>`
 that automatically handles the math for you in case you want to go from one order to the other. It provides the following methods:
 
@@ -168,6 +173,7 @@ $\text{bitreverse} \circ e^{-1} = g^{-1} \circ \text{bitreverse}$
 
 $F$ is called `FFTIndexingUtils::getDFTIndex` and detailed in the users section above.
 
+Please note that this whole discussion and the function $F$ we worked out are only valid in the forward NFFT case. This is because we used a DIF diagram to work out the expression. An expression for the output order of the inverse NFFT should be easy to work out in the same way considering a DIT diagram. However, I did not have a use for it so I didn't bother.
 
 
 ## Unpacking Rule for packed real FFTs
```
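To make the ordering concrete, here is a purely illustrative loop that copies an NFFT result into natural DFT order. It assumes `getDFTIndex(outputIdx)` returns the DFT index corresponding to the NFFT element at `outputIdx` (the role played by $F$ above); the template arguments, buffer names, and unqualified use of `FFTIndexingUtils` are made up, so check the header for the real declarations.

```hlsl
// Illustrative only: unscramble an NFFT (Nabla order) result into natural DFT order.
// The getDFTIndex signature and the location of FFTIndexingUtils are assumptions.
#include "nbl/builtin/hlsl/workgroup/fft.hlsl"

using namespace nbl::hlsl;

RWStructuredBuffer<complex_t<float32_t> > nfft;       // hypothetical: forward FFT output, Nabla order
RWStructuredBuffer<complex_t<float32_t> > naturalDFT; // hypothetical: destination, DFT order

void unscramble()
{
	// ElementsPerInvocationLog2 = 1, WorkgroupSizeLog2 = 8 -> FFTLength = 2^(1+8) = 512
	typedef FFTIndexingUtils<1, 8> indexing_t;
	const uint32_t fftLength = 512u;
	for (uint32_t outputIdx = 0u; outputIdx < fftLength; outputIdx++)
		naturalDFT[indexing_t::getDFTIndex(outputIdx)] = nfft[outputIdx];
}
```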

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 44 additions & 9 deletions
```diff
@@ -17,22 +17,21 @@ namespace fft
 
 template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
 /**
- * @brief Returns the size of the full FFT computed, in terms of number of complex elements.
+ * @brief Returns the size of the full FFT computed, in terms of number of complex elements. If the signal is real, you MUST provide a valid value for `firstAxis`
  *
  * @tparam N Number of dimensions of the signal to perform FFT on.
  *
  * @param [in] dimensions Size of the signal.
- * @param [in] realFFT Indicates whether the signal is real. False by default.
- * @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals. Must be less than N. 0 by default.
+ * @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals, in which case it must be less than N. N by default.
  */
-inline vector<uint64_t, N> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, bool realFFT = false, uint16_t firstAxis = 0u)
+inline vector<uint64_t, N> padDimensions(vector<uint32_t, N> dimensions, uint16_t firstAxis = N)
 {
 	vector<uint32_t, N> newDimensions = dimensions;
 	for (uint16_t i = 0u; i < N; i++)
 	{
 		newDimensions[i] = hlsl::roundUpToPoT(newDimensions[i]);
 	}
-	if (realFFT)
+	if (firstAxis < N)
 		newDimensions[firstAxis] /= 2;
 	return newDimensions;
 }
@@ -52,14 +51,50 @@ template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
  */
 inline uint64_t getOutputBufferSize(
 	uint32_t numChannels,
-	NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions,
+	vector<uint32_t, N> inputDimensions,
 	uint16_t passIx,
-	NBL_CONST_REF_ARG(vector<uint16_t, N>) axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
+	vector<uint16_t, N> axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
 	bool realFFT = false,
 	bool halfFloats = false
 )
 {
-	const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions, realFFT, axisPassOrder[0]);
+	const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions, realFFT ? axisPassOrder[0] : N);
+	vector<bool, N> axesDone = promote<vector<bool, N>, bool>(false);
+	for (uint16_t i = 0; i <= passIx; i++)
+		axesDone[axisPassOrder[i]] = true;
+	const vector<uint32_t, N> passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone);
+	uint64_t numberOfComplexElements = uint64_t(numChannels);
+	for (uint16_t i = 0; i < N; i++)
+		numberOfComplexElements *= uint64_t(passOutputDimension[i]);
+	return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
+}
+
+template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
+/**
+ * @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass, when using the FFT to convolve it against a kernel.
+ *
+ * @tparam N Number of dimensions of the signal to perform FFT on.
+ *
+ * @param [in] numChannels Number of channels of the signal.
+ * @param [in] inputDimensions Size of the signal.
+ * @param [in] kernelDimensions Size of the kernel.
+ * @param [in] passIx Which pass the size is being computed for.
+ * @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw.
+ * @param [in] realFFT True if the signal is real. False by default.
+ * @param [in] halfFloats True if using half-precision floats. False by default.
+ */
+inline uint64_t getOutputBufferSizeConvolution(
+	uint32_t numChannels,
+	vector<uint32_t, N> inputDimensions,
+	vector<uint32_t, N> kernelDimensions,
+	uint16_t passIx,
+	vector<uint16_t, N> axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
+	bool realFFT = false,
+
+	bool halfFloats = false
+)
+{
+	const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions + kernelDimensions, realFFT ? axisPassOrder[0] : N);
 	vector<bool, N> axesDone = promote<vector<bool, N>, bool>(false);
 	for (uint16_t i = 0; i <= passIx; i++)
 		axesDone[axisPassOrder[i]] = true;
@@ -89,7 +124,7 @@ complex_t<Scalar> twiddle(uint32_t k, uint32_t halfN)
 template<bool inverse, typename Scalar>
 struct DIX
 {
-	static void radix2(NBL_CONST_REF_ARG(complex_t<Scalar>) twiddle, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
+	static void radix2(complex_t<Scalar> twiddle, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
 	{
 		plus_assign< complex_t<Scalar> > plusAss;
 		//Decimation in time - inverse
```
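As a quick orientation for the new conventions (the real-FFT flag of `padDimensions` replaced by `firstAxis`, with `firstAxis = N` meaning a complex FFT), here is an illustrative sizing calculation; the image dimensions and channel count are made up.

```hlsl
// Illustrative sizing for a 3-channel, real-valued 1280x720 signal, FFT'd along x then y.
#include "nbl/builtin/hlsl/fft/common.hlsl"

using namespace nbl::hlsl;

void exampleSizing()
{
	const vector<uint32_t, 2> imageDims = vector<uint32_t, 2>(1280, 720);
	const vector<uint16_t, 2> passOrder = vector<uint16_t, 2>(0, 1); // x first, then y

	// Real FFT with the first pass along axis 0: dimensions round up to PoT (2048, 1024),
	// then axis 0 is halved because the real FFT along it is stored packed -> (1024, 1024).
	// Passing firstAxis = 2 (== N) instead would keep the full complex-FFT padding.
	const vector<uint64_t, 2> padded = fft::padDimensions<2>(imageDims, 0);

	// Bytes needed for the output after the first (passIx = 0) and second (passIx = 1) passes,
	// with realFFT = true and full-precision (float32_t) complex values.
	const uint64_t afterFirstPass = fft::getOutputBufferSize<2>(3, imageDims, 0, passOrder, true, false);
	const uint64_t afterSecondPass = fft::getOutputBufferSize<2>(3, imageDims, 1, passOrder, true, false);
}
```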

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 13 additions & 11 deletions
```diff
@@ -54,28 +54,30 @@ struct OptimalFFTParameters
  * @param [in] inputArrayLength The length of the array to run an FFT on
  * @param [in] minSubgroupSize The smallest possible number of threads that can run in a single subgroup. 32 by default.
  */
-inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t minSubgroupSize = 32u)
+inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t minSubgroupSize)
 {
 	NBL_CONSTEXPR_STATIC OptimalFFTParameters invalidParameters = { 0 , 0 };
 
+	if (minSubgroupSize < 4 || maxWorkgroupSize < minSubgroupSize || inputArrayLength <= minSubgroupSize)
+		return invalidParameters;
+
 	// Round inputArrayLength to PoT
-	const uint32_t FFTLength = 1u << (1u + findMSB(_static_cast<uint32_t>(inputArrayLength - 1u)));
+	const uint32_t FFTLength = hlsl::roundUpToPoT(inputArrayLength);
 	// Round maxWorkgroupSize down to PoT
-	const uint32_t actualMaxWorkgroupSize = 1u << (findMSB(maxWorkgroupSize));
-	// This is the logic found in core::roundUpToPoT to get the log2
+	const uint32_t actualMaxWorkgroupSize = hlsl::roundDownToPoT(maxWorkgroupSize);
+	// This is the logic found in hlsl::roundUpToPoT to get the log2
 	const uint16_t workgroupSizeLog2 = _static_cast<uint16_t>(1u + findMSB(_static_cast<uint32_t>(min(FFTLength / 2, actualMaxWorkgroupSize) - 1u)));
-	const uint16_t elementsPerInvocationLog2 = _static_cast<uint16_t>(findMSB(FFTLength >> workgroupSizeLog2));
-	const OptimalFFTParameters retVal = { elementsPerInvocationLog2, workgroupSizeLog2 };
 
 	// Parameters are valid if the workgroup size is at most half of the FFT Length and at least as big as the smallest subgroup that can be launched
-	if ((FFTLength >> workgroupSizeLog2) > 1 && minSubgroupSize <= (1u << workgroupSizeLog2))
-	{
-		return retVal;
-	}
-	else
+	if ((FFTLength >> workgroupSizeLog2) <= 1 || minSubgroupSize > (1u << workgroupSizeLog2))
 	{
 		return invalidParameters;
 	}
+
+	const uint16_t elementsPerInvocationLog2 = _static_cast<uint16_t>(findMSB(FFTLength >> workgroupSizeLog2));
+	const OptimalFFTParameters retVal = { elementsPerInvocationLog2, workgroupSizeLog2 };
+
+	return retVal;
 }
 
 namespace impl
```
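And a hedged usage sketch of the reworked `optimalFFTParameters`, now that `minSubgroupSize` has no default and degenerate inputs return the invalid (all-zero) parameters; the device limits below are made up and the field names are inferred from the aggregate initializer in the diff.

```hlsl
// Illustrative: pick parameters for a 1920-element FFT on a device with max workgroup size 512
// and a minimum subgroup size of 32. Field names are assumed from { elementsPerInvocationLog2, workgroupSizeLog2 }.
#include "nbl/builtin/hlsl/workgroup/fft.hlsl"

using namespace nbl::hlsl;

bool pickParameters()
{
	// The subgroup size argument must now be passed explicitly
	const workgroup::fft::OptimalFFTParameters params = workgroup::fft::optimalFFTParameters(512u, 1920u, 32u);

	// Degenerate inputs (e.g. inputArrayLength <= minSubgroupSize) return the all-zero struct
	if (params.workgroupSizeLog2 == 0)
		return false;

	// 1920 rounds up to FFTLength = 2048: expect workgroupSizeLog2 = 9 (512 threads)
	// and elementsPerInvocationLog2 = 2 (4 elements per thread), since 512 * 4 = 2048.
	return true;
}
```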
