Skip to content

Commit 5518e01

Browse files
committed
Address all PR comments
1 parent eee9904 commit 5518e01

File tree

3 files changed

+26
-26
lines changed

3 files changed

+26
-26
lines changed

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,21 @@ namespace fft
1717

1818
template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
1919
/**
20-
* @brief Returns the size of the full FFT computed, in terms of number of complex elements.
20+
* @brief Returns the size of the full FFT computed, in terms of number of complex elements. If the signal is real, you MUST provide a valid value for `firstAxis`.
2121
*
2222
* @tparam N Number of dimensions of the signal to perform FFT on.
2323
*
2424
* @param [in] dimensions Size of the signal.
25-
* @param [in] realFFT Indicates whether the signal is real. False by default.
26-
* @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals. Must be less than N. 0 by default.
25+
* @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals, in which case it must be less than N. Defaults to N, which means the signal is treated as complex and no axis is halved.
2726
*/
28-
inline vector<uint64_t, N> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, bool realFFT = false, uint16_t firstAxis = 0u)
27+
inline vector<uint64_t, N> padDimensions(vector<uint32_t, N> dimensions, uint16_t firstAxis = N)
2928
{
3029
vector<uint32_t, N> newDimensions = dimensions;
3130
for (uint16_t i = 0u; i < N; i++)
3231
{
3332
newDimensions[i] = hlsl::roundUpToPoT(newDimensions[i]);
3433
}
35-
if (realFFT)
34+
if (firstAxis < N)
3635
newDimensions[firstAxis] /= 2;
3736
return newDimensions;
3837
}
@@ -52,15 +51,14 @@ template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
5251
*/
5352
inline uint64_t getOutputBufferSize(
5453
uint32_t numChannels,
55-
NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions,
54+
vector<uint32_t, N> inputDimensions,
5655
uint16_t passIx,
57-
NBL_CONST_REF_ARG(vector<uint16_t, N>) axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
56+
vector<uint16_t, N> axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
5857
bool realFFT = false,
59-
6058
bool halfFloats = false
6159
)
6260
{
63-
const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions, realFFT, axisPassOrder[0]);
61+
const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions, realFFT ? axisPassOrder[0] : N);
6462
vector<bool, N> axesDone = promote<vector<bool, N>, bool>(false);
6563
for (uint16_t i = 0; i <= passIx; i++)
6664
axesDone[axisPassOrder[i]] = true;
@@ -87,16 +85,16 @@ template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
8785
*/
8886
inline uint64_t getOutputBufferSizeConvolution(
8987
uint32_t numChannels,
90-
NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions,
91-
NBL_CONST_REF_ARG(vector<uint32_t, N>) kernelDimensions,
88+
vector<uint32_t, N> inputDimensions,
89+
vector<uint32_t, N> kernelDimensions,
9290
uint16_t passIx,
93-
NBL_CONST_REF_ARG(vector<uint16_t, N>) axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
91+
vector<uint16_t, N> axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
9492
bool realFFT = false,
9593

9694
bool halfFloats = false
9795
)
9896
{
99-
const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions + kernelDimensions, realFFT, axisPassOrder[0]);
97+
const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions + kernelDimensions, realFFT ? axisPassOrder[0] : N);
10098
vector<bool, N> axesDone = promote<vector<bool, N>, bool>(false);
10199
for (uint16_t i = 0; i <= passIx; i++)
102100
axesDone[axisPassOrder[i]] = true;
@@ -126,7 +124,7 @@ complex_t<Scalar> twiddle(uint32_t k, uint32_t halfN)
126124
template<bool inverse, typename Scalar>
127125
struct DIX
128126
{
129-
static void radix2(NBL_CONST_REF_ARG(complex_t<Scalar>) twiddle, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
127+
static void radix2(complex_t<Scalar> twiddle, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
130128
{
131129
plus_assign< complex_t<Scalar> > plusAss;
132130
//Decimation in time - inverse

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,30 @@ struct OptimalFFTParameters
5454
* @param [in] inputArrayLength The length of the array to run an FFT on
5555
* @param [in] minSubgroupSize The smallest possible number of threads that can run in a single subgroup. Must be at least 4; this parameter has no default value.
5656
*/
57-
inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t minSubgroupSize = 32u)
57+
inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t minSubgroupSize)
5858
{
5959
NBL_CONSTEXPR_STATIC OptimalFFTParameters invalidParameters = { 0 , 0 };
6060

61+
if (minSubgroupSize < 4 || maxWorkgroupSize < minSubgroupSize || inputArrayLength <= minSubgroupSize)
62+
return invalidParameters;
63+
6164
// Round inputArrayLength to PoT
62-
const uint32_t FFTLength = 1u << (1u + findMSB(_static_cast<uint32_t>(inputArrayLength - 1u)));
65+
const uint32_t FFTLength = hlsl::roundUpToPoT(inputArrayLength);
6366
// Round maxWorkgroupSize down to PoT
64-
const uint32_t actualMaxWorkgroupSize = 1u << (findMSB(maxWorkgroupSize));
65-
// This is the logic found in core::roundUpToPoT to get the log2
67+
const uint32_t actualMaxWorkgroupSize = hlsl::roundDownToPoT(maxWorkgroupSize);
68+
// This is the logic found in hlsl::roundUpToPoT to get the log2
6669
const uint16_t workgroupSizeLog2 = _static_cast<uint16_t>(1u + findMSB(_static_cast<uint32_t>(min(FFTLength / 2, actualMaxWorkgroupSize) - 1u)));
67-
const uint16_t elementsPerInvocationLog2 = _static_cast<uint16_t>(findMSB(FFTLength >> workgroupSizeLog2));
68-
const OptimalFFTParameters retVal = { elementsPerInvocationLog2, workgroupSizeLog2 };
6970

7071
// Parameters are valid if the workgroup size is at most half of the FFT Length and at least as big as the smallest subgroup that can be launched
71-
if ((FFTLength >> workgroupSizeLog2) > 1 && minSubgroupSize <= (1u << workgroupSizeLog2))
72-
{
73-
return retVal;
74-
}
75-
else
72+
if ((FFTLength >> workgroupSizeLog2) <= 1 || minSubgroupSize > (1u << workgroupSizeLog2))
7673
{
7774
return invalidParameters;
7875
}
76+
77+
const uint16_t elementsPerInvocationLog2 = _static_cast<uint16_t>(findMSB(FFTLength >> workgroupSizeLog2));
78+
const OptimalFFTParameters retVal = { elementsPerInvocationLog2, workgroupSizeLog2 };
79+
80+
return retVal;
7981
}
8082

8183
namespace impl

0 commit comments

Comments
 (0)