Skip to content

Commit 6401e53

Browse files
committed
Addressed PR review comments
1 parent 2eb0ffd commit 6401e53

File tree

6 files changed

+100
-34
lines changed

6 files changed

+100
-34
lines changed

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,63 @@ namespace hlsl
1414
namespace fft
1515
{
1616

17-
// template parameter N controls the number of dimensions of the input
18-
// template parameter M controls the number of dimensions to pad up to PoT
19-
// "axes" indicates which dimensions to pad up to PoT
20-
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
21-
inline vector<uint64_t, 3> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false)
17+
18+
template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
19+
/**
20+
* @brief Returns the per-axis size of the full FFT computed, in terms of number of complex elements.
21+
*
22+
* @tparam N Number of dimensions of the signal to perform FFT on.
23+
*
24+
* @param [in] dimensions Size of the signal.
25+
* @param [in] realFFT Indicates whether the signal is real. False by default.
26+
* @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals. Must be less than N. 0 by default.
27+
*/
28+
inline vector<uint64_t, N> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, bool realFFT = false, uint16_t firstAxis = 0u)
2229
{
2330
vector<uint32_t, N> newDimensions = dimensions;
24-
uint16_t axisCount = 0;
25-
for (uint16_t i = 0u; i < M; i++)
31+
for (uint16_t i = 0u; i < N; i++)
2632
{
2733
newDimensions[i] = hlsl::roundUpToPoT(newDimensions[i]);
28-
if (realFFT && !axisCount++)
29-
newDimensions[i] /= 2;
3034
}
35+
if (realFFT)
36+
newDimensions[firstAxis] /= 2;
3137
return newDimensions;
3238
}
3339

34-
// template parameter N controls the number of dimensions of the input
35-
// template parameter M controls the number of dimensions we run an FFT along AND store the result
36-
// "axes" indicates which dimensions we run an FFT along AND store the result
37-
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
38-
inline uint64_t getOutputBufferSize(NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false, bool halfFloats = false)
40+
template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
41+
/**
42+
* @brief Returns the size, in bytes, required by a buffer to hold the result of the FFT of a signal after a certain pass.
43+
*
44+
* @tparam N Number of dimensions of the signal to perform FFT on.
45+
*
46+
* @param [in] numChannels Number of channels of the signal.
47+
* @param [in] inputDimensions Size of the signal.
48+
* @param [in] passIx Which pass the size is being computed for.
49+
* @param [in] axisPassOrder Order of the axes along which the FFT is computed. Default is xyzw.
50+
* @param [in] realFFT True if the signal is real. False by default.
51+
* @param [in] halfFloats True if using half-precision floats. False by default.
52+
*/
53+
inline uint64_t getOutputBufferSize(
54+
uint32_t numChannels,
55+
NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions,
56+
uint16_t passIx,
57+
NBL_CONST_REF_ARG(vector<uint16_t, N>) axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
58+
bool realFFT = false,
59+
bool halfFloats = false
60+
)
3961
{
40-
const vector<uint64_t, 3> paddedDims = padDimensions<N, M>(inputDimensions, axes);
41-
const uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * uint64_t(numChannels);
62+
const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions, realFFT, axisPassOrder[0]);
63+
vector<bool, N> axesDone = promote<vector<bool, N>, bool>(false);
64+
for (uint16_t i = 0; i <= passIx; i++)
65+
axesDone[axisPassOrder[i]] = true;
66+
const vector<uint32_t, N> passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone);
67+
uint64_t numberOfComplexElements = uint64_t(numChannels);
68+
for (uint16_t i = 0; i < N; i++)
69+
numberOfComplexElements *= uint64_t(passOutputDimension[i]);
4270
return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
4371
}
4472

73+
4574
// Computes the kth element in the group of N roots of unity
4675
// Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT)
4776
template<bool inverse, typename Scalar>
@@ -95,11 +124,33 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
95124
lo = x;
96125
}
97126

98-
// Bit-reverses T as a binary string of length given by Bits
99-
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_integral_v<T> && Bits <= sizeof(T) * 8)
127+
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_unsigned_v<T>&& Bits <= sizeof(T) * 8)
128+
/**
129+
* @brief Takes the binary representation of `value` as a string of `Bits` bits and returns a value of the same type resulting from reversing the string.
130+
*
131+
* @tparam T Type of the value to operate on.
132+
* @tparam Bits The length of the string of bits used to represent `value`.
133+
*
134+
* @param [in] value The value to bitreverse.
135+
*/
100136
T bitReverseAs(T value)
101137
{
102-
return hlsl::bitReverse<uint32_t>(value) >> (sizeof(T) * 8 - Bits);
138+
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - Bits));
139+
}
140+
141+
template<typename T NBL_FUNC_REQUIRES(is_unsigned_v<T>)
142+
/**
143+
* @brief Takes the binary representation of `value` and returns a value of the same type resulting from reversing the string of bits as if it were `bits` long.
144+
* Keep in mind `bits` cannot exceed `8 * sizeof(T)`.
145+
*
146+
* @tparam T type of the value to operate on.
147+
*
148+
* @param [in] value The value to bitreverse.
149+
* @param [in] bits The length of the string of bits used to represent `value`.
150+
*/
151+
T bitReverseAs(T value, uint16_t bits)
152+
{
153+
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - bits));
103154
}
104155

105156
}

include/nbl/builtin/hlsl/glsl_compat/core.hlsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77
#include "nbl/builtin/hlsl/cpp_compat/basic.h"
88
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
99
#include "nbl/builtin/hlsl/type_traits.hlsl"
10-
<<<<<<< HEAD
11-
#include "nbl/builtin/hlsl/bit.hlsl"
12-
=======
1310
#include "nbl/builtin/hlsl/spirv_intrinsics/glsl.std.450.hlsl"
14-
>>>>>>> master
1511

1612
namespace nbl
1713
{

include/nbl/builtin/hlsl/math/intutil.hlsl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,9 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer align(Integer alignment, Integer size,
5858
return address = nextAlignedAddr;
5959
}
6060

61+
// ------------------------------------- CPP ONLY ----------------------------------------------------------
6162
#ifndef __HLSL_VERSION
6263

63-
// Have to wait for the HLSL patch for `is_enum`. Would also have to figure out how to do it without initializer lists for HLSL use.
64-
6564
//! Get bitmask from variadic arguments passed.
6665
/*
6766
For example if you were to create bitmask for vertex attributes

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,43 @@ struct OptimalFFTParameters
3838
{
3939
uint16_t elementsPerInvocationLog2 : 8;
4040
uint16_t workgroupSizeLog2 : 8;
41+
42+
// Used to check if the parameters returned by `optimalFFTParameters` are valid
43+
bool areValid()
44+
{
45+
return elementsPerInvocationLog2 > 0 && workgroupSizeLog2 > 0;
46+
}
4147
};
4248

43-
inline OptimalFFTParameters optimalFFTParameters(const uint32_t maxWorkgroupSize, uint32_t inputArrayLength)
49+
/**
50+
* @brief Returns the best parameters (according to our metric) to run an FFT
51+
*
52+
* @param [in] maxWorkgroupSize The max number of threads that can be launched in a single workgroup
53+
* @param [in] inputArrayLength The length of the array to run an FFT on
54+
* @param [in] minSubgroupSize The smallest possible number of threads that can run in a single subgroup. 32 by default.
55+
*/
56+
inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t minSubgroupSize = 32u)
4457
{
58+
NBL_CONSTEXPR_STATIC OptimalFFTParameters invalidParameters = { 0 , 0 };
59+
4560
// Round inputArrayLength to PoT
46-
uint32_t FFTLength = 1u << (1u + findMSB(_static_cast<uint32_t>(inputArrayLength - 1u)));
61+
const uint32_t FFTLength = 1u << (1u + findMSB(_static_cast<uint32_t>(inputArrayLength - 1u)));
4762
// Round maxWorkgroupSize down to PoT
48-
uint32_t actualMaxWorkgroupSize = 1u << (findMSB(maxWorkgroupSize));
63+
const uint32_t actualMaxWorkgroupSize = 1u << (findMSB(maxWorkgroupSize));
4964
// This is the logic found in core::roundUpToPoT to get the log2
5065
const uint16_t workgroupSizeLog2 = _static_cast<uint16_t>(1u + findMSB(_static_cast<uint32_t>(min(FFTLength / 2, actualMaxWorkgroupSize) - 1u)));
51-
#ifndef __HLSL_VERSION
52-
assert((FFTLength >> workgroupSizeLog2) > 1);
53-
#endif
5466
const uint16_t elementsPerInvocationLog2 = _static_cast<uint16_t>(findMSB(FFTLength >> workgroupSizeLog2));
5567
const OptimalFFTParameters retVal = { elementsPerInvocationLog2, workgroupSizeLog2 };
56-
return retVal;
68+
69+
// Parameters are valid if the workgroup size is at most half of the FFT Length and at least as big as the smallest subgroup that can be launched
70+
if ((FFTLength >> workgroupSizeLog2) > 1 && minSubgroupSize <= (1u << workgroupSizeLog2))
71+
{
72+
return retVal;
73+
}
74+
else
75+
{
76+
return invalidParameters;
77+
}
5778
}
5879

5980
}

src/nbl/builtin/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/spirv_intrinsics/glsl.std.450
242242
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat.hlsl")
243243
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat/basic.h")
244244
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat/intrinsics.hlsl")
245-
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat/impl/intrinsics_impl.hlsl")
246245
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat/matrix.hlsl")
247246
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat/promote.hlsl")
248247
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/cpp_compat/vector.hlsl")

0 commit comments

Comments
 (0)