Commit 2782d8d

FFT readme update and accessor concept change to remove memory barriers
1 parent ae5dbad commit 2782d8d

4 files changed: +62 -50 lines changed

include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl

Lines changed: 1 addition & 19 deletions
@@ -42,7 +42,7 @@ NBL_CONCEPT_END(
 // * void get(uint32_t index, inout complex_t<Scalar> value);
 // * void set(uint32_t index, in complex_t<Scalar> value);
 
-#define NBL_CONCEPT_NAME SmallFFTAccessor
+#define NBL_CONCEPT_NAME FFTAccessor
 #define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
 #define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
 #define NBL_CONCEPT_PARAM_0 (accessor, T)
@@ -61,24 +61,6 @@ NBL_CONCEPT_END(
 #undef accessor
 #include <nbl/builtin/hlsl/concepts/__end.hlsl>
 
-
-// The Accessor MUST provide the following methods:
-// * void get(uint32_t index, inout complex_t<Scalar> value);
-// * void set(uint32_t index, in complex_t<Scalar> value);
-// * void memoryBarrier();
-
-#define NBL_CONCEPT_NAME FFTAccessor
-#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
-#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
-#define NBL_CONCEPT_PARAM_0 (accessor, T)
-NBL_CONCEPT_BEGIN(1)
-#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
-NBL_CONCEPT_END(
-((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))
-) && SmallFFTAccessor<T, Scalar>;
-#undef accessor
-#include <nbl/builtin/hlsl/concepts/__end.hlsl>
-
 }
 }
 }
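
Reviewer note: after this change the single remaining `FFTAccessor` concept asks only for `get` and `set`. As a rough illustration of that surface, here is a plain C++ sketch (not HLSL, and not code from this repository). `VectorAccessorSketch` and `butterflySketch` are hypothetical names, `std::complex` stands in for `complex_t<Scalar>`, and the butterfly is a generic radix-2 step chosen only to exercise the two methods, not the library's actual kernel.

```cpp
#include <cassert>
#include <complex>
#include <cstdint>
#include <vector>

// Hypothetical accessor backed by a std::vector. It exposes only the two
// methods the merged concept lists: get(index, value) and set(index, value).
template <typename Scalar>
struct VectorAccessorSketch
{
    std::vector<std::complex<Scalar>> data;

    void get(uint32_t index, std::complex<Scalar>& value) const
    {
        assert(index < data.size());
        value = data[index];
    }

    void set(uint32_t index, const std::complex<Scalar>& value)
    {
        assert(index < data.size());
        data[index] = value;
    }
    // Intentionally no memoryBarrier(): the concept no longer requires one.
};

// Generic radix-2 butterfly that touches the array only through the accessor.
// Illustrative only; the real FFT code decides the indices and twiddles.
template <typename Scalar, typename Accessor>
void butterflySketch(Accessor& accessor, uint32_t loIx, uint32_t hiIx,
                     std::complex<Scalar> twiddle)
{
    std::complex<Scalar> lo, hi;
    accessor.get(loIx, lo);
    accessor.get(hiIx, hi);
    const std::complex<Scalar> t = hi * twiddle;
    hi = lo - t;
    lo = lo + t;
    accessor.set(loIx, lo);
    accessor.set(hiIx, hi);
}
```

Any type with these two methods could be passed to `butterflySketch`, which is the point of dropping the barrier from the requirement: the accessor stays a pure load/store interface.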

include/nbl/builtin/hlsl/fft/README.md

Lines changed: 58 additions & 26 deletions
@@ -8,16 +8,30 @@ To run an FFT, you need to call the FFT struct's static `__call` method. You do
 
 * `Inverse` indicates whether you're running a forward or an inverse FFT
 
-* `ConstevalParameters` is a struct created from three compile-time constants: `ElementsPerInvocationLog2`, `WorkgroupSizeLog2` and `Scalar`. `Scalar` is just the scalar type for the complex numbers involved, `WorkgroupSizeLog2` is self-explanatory, and `ElementsPerInvocationLog2` is the (log of) the number of elements of the array each thread is tasked with computing, with the total ElementsPerInvocation being the length `FFTLength` of the array to perform an FFT on (remember it must be PoT) divided by the workgroup size used. This makes both `ElementsPerInvocation` and `WorkgroupSize` be PoT.
+* `ConstevalParameters` is a struct created from three compile-time constants: `ElementsPerInvocationLog2`, `WorkgroupSizeLog2` and `Scalar`. `Scalar` is just the scalar type for the complex numbers involved, `WorkgroupSizeLog2` is self-explanatory, and `ElementsPerInvocationLog2` is the (log of) the number of elements of the array each thread is tasked with computing, with the total `ElementsPerInvocation` being the length `FFTLength` of the array to perform an FFT on (remember it must be PoT) divided by the workgroup size used. This makes both `ElementsPerInvocation` and `WorkgroupSize` be PoT.
 IMPORTANT: You MUST launch kernel with a workgroup size of `ConstevalParameters::WorkgroupSize`
 
-* `Accessor` is an accessor to the array. It MUST provide the methods
-`template <typename AccessType> void set(uint32_t idx, AccessType value)` and
-`template <typename AccessType> void get(uint32_t idx, NBL_REF_ARG(AccessType) value)`
-which are hopefully self-explanatory. These methods need to be able to be specialized with `AccessType` being `complex_t<Scalar>` for the FFT to work properly.
-Furthermore, if doing an FFT with `ElementsPerInvocationLog2 > 1`, it MUST also provide a `void memoryBarrier()` method. If not accessing any type of memory during the FFT, it can be a method that does nothing. Otherwise, it must do a barrier with `AcquireRelease` semantics, with proper semantics for the type of memory it accesses. This example uses an Accessor going straight to global memory, so it requires a memory barrier. For an example of an accessor that doesn't, see the `28_FFTBloom` example, where we use preloaded accessors.
+* `Accessor` is an accessor to the array. It MUST provide the methods
+```cpp
+template <typename AccessType>
+void set(uint32_t idx, AccessType value);
+
+template <typename AccessType>
+void get(uint32_t idx, NBL_REF_ARG(AccessType) value);
+```
+These methods need to be able to be specialized with `AccessType` being `complex_t<Scalar>` for the FFT to work properly.
+Furthermore, if doing an FFT with `ElementsPerInvocationLog2 > 1`, it MUST also provide a `void memoryBarrier()` method. If not accessing any type of memory during the FFT, it can be a method that does nothing. Otherwise, it must do a barrier with `AcquireRelease` semantics, with proper semantics for the type of memory it accesses.
 
-* `SharedMemoryAccessor` is an accessor to a shared memory array of `uint32_t` that MUST be able to fit `WorkgroupSize` many complex elements (one per thread). When instantiating a `workgroup::fft::ConstevalParameters` struct, you can access its static member field `SharedMemoryDWORDs` that yields the amount of `uint32_t`s the shared memory array must be able to hold. It MUST provide the methods
+* `SharedMemoryAccessor` is an accessor to a shared memory array of `uint32_t` that MUST be able to fit `WorkgroupSize` many complex elements (one per thread). When instantiating a `workgroup::fft::ConstevalParameters` struct, you can access its static member field `SharedMemoryDWORDs` that yields the amount of `uint32_t`s the shared memory array must be able to hold. It MUST provide the methods
+```cpp
+template <typename IndexType, typename AccessType>
+void set(IndexType idx, AccessType value);
+
+template <typename IndexType, typename AccessType>
+void get(IndexType idx, NBL_REF_ARG(AccessType) value);
+
+void workgroupExecutionAndMemoryBarrier();
+```
 `template <typename IndexType, typename AccessType> void set(IndexType idx, AccessType value)`,
 `template <typename IndexType, typename AccessType> void get(IndexType idx, NBL_REF_ARG(AccessType) value)` and
 `void workgroupExecutionAndMemoryBarrier()`.
@@ -31,46 +45,61 @@ Furthermore, you must define the method `uint32_t3 nbl::hlsl::glsl::gl_WorkGroup
 ### Figuring out the storage required for an FFT
 We provide the functions
 ```cpp
-uint64_t fft::getOutputBufferSize(
+template <uint16_t N>
+uint64_t getOutputBufferSize(
 uint32_t numChannels,
 vector<uint32_t, N> inputDimensions,
 uint16_t passIx,
-vector<uint16_t, N> axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
-bool realFFT = false,
-bool halfFloats = false);
-
+vector<uint16_t, N> axisPassOrder,
+bool realFFT,
+bool halfFloats
+)
+template <uint16_t N>
 uint64_t getOutputBufferSizeConvolution(
 uint32_t numChannels,
 vector<uint32_t, N> inputDimensions,
 vector<uint32_t, N> kernelDimensions,
 uint16_t passIx,
-vector<uint16_t, N> axisPassOrder = _static_cast<vector<uint16_t, N> >(uint16_t4(0, 1, 2, 3)),
-bool realFFT = false,
-bool halfFloats = false
+vector<uint16_t, N> axisPassOrder,
+bool realFFT,
+bool halfFloats
 )
 ```
-which yield the size (in bytes) required to store the result of an FFT of a signal with `numChannels` channels of size `inputDImensions` after running the FFT along the axis `axisPassOrder[passIx]` (if you don't
-provide this order it's assumed to be `xyzw`). It furthermore takes an argument `realFFT` which if true means you are doing an FFT on a real signal AND you want to store the output of the FFT along the first axis
+in the `fft` namespace which yield the size (in bytes) required to store the result of an FFT of a signal with `numChannels` channels of size `inputDImensions` after running the FFT along the axis `axisPassOrder[passIx]` (if you don't
+provide this order it's assumed to be `xyzw`). This assumes that you don't run or store any unnecessary FFTs, since with wrapping modes it's always possible to recover the result in the padding area (sampling outside of $[0,1)$ along some axis).
+
+It furthermore takes an argument `realFFT` which if true means you are doing an FFT on a real signal AND you want to store the output of the FFT along the first axis
 in a compact manner (knowing that FFTs of real signals are conjugate-symmetric). By default it assumes your complex numbers have `float32_t` scalars, `halfFloats` set to true means you're using `float16_t` scalars.
 
 `getOutputBufferSizeConvolution` furthermore takes a `kernelDimensions` argument. When convolving a signal against a kernel, the FFT has some extra padding to consider, so these methods are different.
 
+### Figuring out amount of Shared Memory necessary to run an FFT
+After instantiating it, we can access the `constexpr uint32_t ConstevalParameters::SharedMemoryDWORDs` that tells us the size (in number of `uint32_t`s) that the shared memory array must have.
+
 ### Figuring out compile-time parameters
-We provide a
-`workgroup::fft::optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength)`
-function, which yields possible values for `ElementsPerInvocationLog2` and `WorkgroupSizeLog2` you might want to use to instantiate a `ConstevalParameters` struct.
+We provide a
+```cpp
+OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength);
+```
+function in the `workgroup::fft` namespace, which yields possible values for `ElementsPerInvocationLog2` and `WorkgroupSizeLog2` you might want to use to instantiate a `ConstevalParameters` struct, packed in a `OptimalFFTParameters` struct.
+
+By default, we prefer to use only 2 elements per invocation when possible, and only use more if
+$2 \cdot \text{maxWorkgroupSize} < \text{inputArrayLength}$. This is because using more elements per thread either results in more accesses to the array via the `Accessor` or, if using preloaded accessors, it results in lower occupancy.
 
-By default, we prefer to use only 2 elements per invocation when possible, and only use more if $2 \cdot \text{maxWorkgroupSize} < \text{inputArrayLength}$. This is because using more elements per thread either results in more accesses to the array via the `Accessor` or, if using preloaded accessors, it results in lower occupancy.
+`inputArrayLength` can be arbitrary, but please do note that the parameters returned will be for running an FFT on an array of length `roundUpToPoT(inputArrayLength)` and YOU are responsible for padding your data up to that size.
 
-`inputArrayLength` can be arbitrary, but please do note that the parameters returned will be for running an FFT on an array of length `roundUpToPoT(inputArrayLength)` and YOU are responsible for padding your data up to that size. You are, of course, free to choose whatever parameters are better for your use case, this is just a default.
+You are, of course, free to choose whatever `ConstevalParameters` are better for your use case, this is just a default.
 
 ### Indexing
-We made some decisions in the design of the FFT algorithm pertaining to load/store order. In particular we wanted to keep stores linear to minimize cache misses when writing the output of an FFT. As such, the output of the FFT is not in its normal order, nor in bitreversed order (which is the standard for Cooley-Tukey implementations). Instead, it's in what we will refer to Nabla order going forward. The Nabla order allows for coalesced writes of the output.
+We made some decisions in the design of the FFT algorithm pertaining to load/store order. In particular we wanted to keep stores linear to minimize cache misses when writing the output of an FFT. As such, the output of the FFT is not in its normal order, nor in bitreversed order (which is the standard for Cooley-Tukey implementations). Instead, it's in what we will refer to Nabla order going forward. The Nabla order allows for coalesced writes of the output, and is essentially the "natural order" of the output of our algorithm, meaning it's the order of the output that doesn't require incurring in any extra ordering operations.
 
 This whole discussion applies to our implementation of the forward FFT only. We have not yet implemented the same functions for the inverse FFT since we didn't have a need for it.
 
 The result of a forward FFT will be referred to as an $\text{NFFT}$ (N for Nabla). This $\text{NFFT}$ contains the same elements as the $\text{DFT}$ (which is the properly-ordered result of an FFT) of the same signal, just in Nabla order. We provide a struct
-`FFTIndexingUtils<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>`
+```cpp
+template <uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
+struct FFTIndexingUtils;
+```
 that automatically handles the math for you in case you want to go from one order to the other. It provides the following methods:
 
 * `uint32_t getDFTIndex(uint32_t outputIdx)`: given an index $\text{outputIdx}$ into the $\text{NFFT}$, it yields its corresponding $\text{freqIdx}$ into the $\text{DFT}$, such that
@@ -102,8 +131,11 @@ This works assuming that each workgroup shuffle is associated with the same
 $\text{localElementIndex}$ for every thread. The question now becomes, how does a thread know which value it has to send in this shuffle?
 
 The functions
-`FFTMirrorTradeUtils::getNablaMirrorLocalInfo(uint32_t globalElementIndex)` and
-`FFTMirrorTradeUtils::getNablaMirrorGlobalInfo(uint32_t globalElementIndex)`
+```cpp
+NablaMirrorLocalInfo FFTMirrorTradeUtils::getNablaMirrorLocalInfo(uint32_t globalElementIndex);
+
+NablaMirrorGlobalInfo FFTMirrorTradeUtils::getNablaMirrorGlobalInfo(uint32_t globalElementIndex);
+```
 handle this for you: given a $\text{globalElementIndex}$, `getNablaMirrorLocalInfo` returns a struct with a field `otherThreadID` (the one we will receive a value from in the shuffle) and a field `mirrorLocalIndex` which is the $\text{localElementIndex}$ *of the element we should write to the shared memory array*.
 
 `getNablaMirrorGlobalInfo` returns the same info but with a `mirrorGlobalIndex` instead, so instead of returning the $\text{localElementIndex}$ of the element we have to send, it returns its $\text{globalElementIndex}$.
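
Reviewer note: the parameter-selection rule and the shared-memory sizing described in the README changes above can be sketched in plain C++ as follows. This is a reconstruction from the README prose only, not the library's implementation. All names ending in `Sketch` are hypothetical, `maxWorkgroupSize` is assumed to be a power of two, and the shared-memory formula is inferred from the statement that the array must fit `WorkgroupSize` complex elements.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the README's `OptimalFFTParameters` struct.
struct OptimalFFTParametersSketch
{
    uint16_t elementsPerInvocationLog2;
    uint16_t workgroupSizeLog2;
};

// Smallest power of two >= x, for 1 <= x <= 2^31.
inline uint32_t roundUpToPoT(uint32_t x)
{
    assert(x >= 1 && x <= (1u << 31));
    uint32_t pot = 1;
    while (pot < x)
        pot <<= 1;
    return pot;
}

// log2 of a power of two.
inline uint16_t log2PoT(uint32_t pot)
{
    uint16_t l = 0;
    while ((1u << l) < pot)
        ++l;
    return l;
}

// Prefer 2 elements per invocation; use more only when
// 2 * maxWorkgroupSize < FFT length (assumption: maxWorkgroupSize is PoT).
inline OptimalFFTParametersSketch optimalFFTParametersSketch(
    uint32_t maxWorkgroupSize, uint32_t inputArrayLength)
{
    assert(maxWorkgroupSize >= 1 && (maxWorkgroupSize & (maxWorkgroupSize - 1)) == 0);
    assert(inputArrayLength >= 2);
    const uint16_t fftLengthLog2 = log2PoT(roundUpToPoT(inputArrayLength));
    const uint16_t maxWorkgroupSizeLog2 = log2PoT(maxWorkgroupSize);

    OptimalFFTParametersSketch params;
    if (fftLengthLog2 <= maxWorkgroupSizeLog2 + 1)
    {
        // The whole (padded) array fits with 2 elements per thread.
        params.elementsPerInvocationLog2 = 1;
        params.workgroupSizeLog2 = static_cast<uint16_t>(fftLengthLog2 - 1);
    }
    else
    {
        // Saturate the workgroup and give each thread more elements.
        params.workgroupSizeLog2 = maxWorkgroupSizeLog2;
        params.elementsPerInvocationLog2 =
            static_cast<uint16_t>(fftLengthLog2 - maxWorkgroupSizeLog2);
    }
    return params;
}

// Inferred sizing: WorkgroupSize complex elements (2 scalars each),
// counted in 32-bit words.
inline uint32_t sharedMemoryDWORDsSketch(uint16_t workgroupSizeLog2,
                                         uint32_t scalarSizeBytes)
{
    return ((2u * scalarSizeBytes) << workgroupSizeLog2) / 4u;
}
```

Under these assumptions, an input of length 1000 with a maximum workgroup size of 256 is padded to 1024 and handled with 4 elements per invocation, and the product of the two returned sizes always equals the padded FFT length.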

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 2 additions & 4 deletions
@@ -324,7 +324,7 @@ struct FFT<false, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device
 }
 
 
-template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::SmallFFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
+template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
 static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
 {
 NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -390,7 +390,7 @@ struct FFT<true, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_
 }
 
 
-template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::SmallFFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
+template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
 static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
 {
 NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -475,7 +475,6 @@ struct FFT<false, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupS
 accessor.set(loIx, lo);
 accessor.set(hiIx, hi);
 }
-accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
 }
 
 // do ElementsPerInvocation/2 small workgroup FFTs
@@ -522,7 +521,6 @@ struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSi
 [unroll]
 for (uint32_t stride = 2 * WorkgroupSize; stride < ElementsPerInvocation * WorkgroupSize; stride <<= 1)
 {
-accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
 [unroll]
 for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (ElementsPerInvocation / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
 {