Skip to content

Commit eee9904

Browse files
committed
Update accessor concepts
1 parent e8f46dd commit eee9904

File tree

3 files changed

+49
-6
lines changed

3 files changed

+49
-6
lines changed

include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ NBL_CONCEPT_BEGIN(3)
5353
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
5454
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
5555
NBL_CONCEPT_END(
56-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
57-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
56+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<complex_t<Scalar> >(index, val)), is_same_v, void))
57+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<complex_t<Scalar> >(index, val)), is_same_v, void))
5858
);
5959
#undef val
6060
#undef index

include/nbl/builtin/hlsl/fft/README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ To run an FFT, you need to call the FFT struct's static `__call` method. You do
1212
IMPORTANT: You MUST launch kernel with a workgroup size of `ConstevalParameters::WorkgroupSize`
1313

1414
* `Accessor` is an accessor to the array. It MUST provide the methods
15-
`void get(uint32_t index, inout complex_t<Scalar> value)`,
16-
`void set(uint32_t index, in complex_t<Scalar> value)`,
17-
which are hopefully self-explanatory. Furthermore, if doing an FFT with `ElementsPerInvocationLog2 > 1`, it MUST also provide a `void memoryBarrier()` method. If not accessing any type of memory during the FFT, it can be a method that does nothing. Otherwise, it must do a barrier with `AcquireRelease` semantics, with proper semantics for the type of memory it accesses. This example uses an Accessor going straight to global memory, so it requires a memory barrier. For an example of an accessor that doesn't, see the `28_FFTBloom` example, where we use preloaded accessors.
15+
`template <typename AccessType> void set(uint32_t idx, AccessType value)` and
16+
`template <typename AccessType> void get(uint32_t idx, NBL_REF_ARG(AccessType) value)`
17+
which are hopefully self-explanatory. These methods need to be able to be specialized with `AccessType` being `complex_t<Scalar>` for the FFT to work properly.
18+
Furthermore, if doing an FFT with `ElementsPerInvocationLog2 > 1`, it MUST also provide a `void memoryBarrier()` method. If not accessing any type of memory during the FFT, it can be a method that does nothing. Otherwise, it must do a barrier with `AcquireRelease` semantics, with proper semantics for the type of memory it accesses. This example uses an Accessor going straight to global memory, so it requires a memory barrier. For an example of an accessor that doesn't, see the `28_FFTBloom` example, where we use preloaded accessors.
1819

1920
* `SharedMemoryAccessor` is an accessor to a shared memory array of `uint32_t` that MUST be able to fit `WorkgroupSize` many complex elements (one per thread). When instantiating a `workgroup::fft::ConstevalParameters` struct, you can access its static member field `SharedMemoryDWORDs` that yields the amount of `uint32_t`s the shared memory array must be able to hold. It MUST provide the methods
2021
`template <typename IndexType, typename AccessType> void set(IndexType idx, AccessType value)`,
@@ -27,6 +28,8 @@ Furthermore, you must define the method `uint32_t3 nbl::hlsl::glsl::gl_WorkGroup
2728

2829
## Utils
2930

31+
### Figuring out the storage required for an FFT
32+
3033
### Figuring out compile-time parameters
3134
We provide a
3235
`workgroup::fft::optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength)`
@@ -39,7 +42,9 @@ By default, we prefer to use only 2 elements per invocation when possible, and o
3942
### Indexing
4043
We made some decisions in the design of the FFT algorithm pertaining to load/store order. In particular we wanted to keep stores linear to minimize cache misses when writing the output of an FFT. As such, the output of the FFT is not in its normal order, nor in bitreversed order (which is the standard for Cooley-Tukey implementations). Instead, it's in what we will refer to Nabla order going forward. The Nabla order allows for coalesced writes of the output.
4144

42-
The result of an FFT (either forward or inverse, assuming the input is in its natural order) will be referred to as an $\text{NFFT}$ (N for Nabla). This $\text{NFFT}$ contains the same elements as the $\text{DFT}$ (which is the properly-ordered result of an FFT) of the same signal, just in Nabla order. We provide a struct
45+
This whole discussion applies to our implementation of the forward FFT only. We have not yet implemented the same functions for the inverse FFT since we didn't have a need for it.
46+
47+
The result of a forward FFT will be referred to as an $\text{NFFT}$ (N for Nabla). This $\text{NFFT}$ contains the same elements as the $\text{DFT}$ (which is the properly-ordered result of an FFT) of the same signal, just in Nabla order. We provide a struct
4348
`FFTIndexingUtils<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>`
4449
that automatically handles the math for you in case you want to go from one order to the other. It provides the following methods:
4550

@@ -168,6 +173,7 @@ $\text{bitreverse} \circ e^{-1} = g^{-1} \circ \text{bitreverse}$
168173

169174
$F$ is called `FFTIndexingUtils::getDFTIndex` and detailed in the users section above.
170175

176+
Please note that this whole discussion and the function $F$ we worked out are only valid in the forward NFFT case. This is because we used a DIF diagram to work out the expression. An expression for the output order of the inverse NFFT should be easy to work out in the same way considering a DIT diagram. However, I did not have a use for it so I didn't bother.
171177

172178

173179
## Unpacking Rule for packed real FFTs

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ inline uint64_t getOutputBufferSize(
5656
uint16_t passIx,
5757
NBL_CONST_REF_ARG(vector<uint16_t, N>) axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
5858
bool realFFT = false,
59+
5960
bool halfFloats = false
6061
)
6162
{
@@ -70,6 +71,42 @@ inline uint64_t getOutputBufferSize(
7071
return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
7172
}
7273

74+
template <uint16_t N NBL_FUNC_REQUIRES(N > 0 && N <= 4)
75+
/**
76+
* @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass, when using the FFT to convolve it against a kernel.
77+
*
78+
* @tparam N Number of dimensions of the signal to perform FFT on.
79+
*
80+
* @param [in] numChannels Number of channels of the signal.
81+
* @param [in] inputDimensions Size of the signal.
82+
* @param [in] kernelDimensions Size of the kernel.
83+
* @param [in] passIx Which pass the size is being computed for.
84+
* @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw.
85+
* @param [in] realFFT True if the signal is real. False by default.
86+
* @param [in] halfFloats True if using half-precision floats. False by default.
87+
*/
88+
inline uint64_t getOutputBufferSizeConvolution(
89+
uint32_t numChannels,
90+
NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions,
91+
NBL_CONST_REF_ARG(vector<uint32_t, N>) kernelDimensions,
92+
uint16_t passIx,
93+
NBL_CONST_REF_ARG(vector<uint16_t, N>) axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
94+
bool realFFT = false,
95+
96+
bool halfFloats = false
97+
)
98+
{
99+
const vector<uint32_t, N> paddedDimensions = padDimensions<N>(inputDimensions + kernelDimensions, realFFT, axisPassOrder[0]);
100+
vector<bool, N> axesDone = promote<vector<bool, N>, bool>(false);
101+
for (uint16_t i = 0; i <= passIx; i++)
102+
axesDone[axisPassOrder[i]] = true;
103+
const vector<uint32_t, N> passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone);
104+
uint64_t numberOfComplexElements = uint64_t(numChannels);
105+
for (uint16_t i = 0; i < N; i++)
106+
numberOfComplexElements *= uint64_t(passOutputDimension[i]);
107+
return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
108+
}
109+
73110

74111
// Computes the kth element in the group of N roots of unity
75112
// Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT)

0 commit comments

Comments
 (0)