Commit db70f9c

Merge remote-tracking branch 'origin/more_fft_utils' into bindless_blit

Authored and committed by devsh
2 parents: 6796fa6 + 0365e30

File tree: 5 files changed, +182 -79 lines

3rdparty/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -235,6 +235,7 @@ if(_NBL_COMPILE_WITH_OPEN_EXR_)
 	set(BUILD_TESTING ${_OLD_BUILD_TESTING})
 endif()
 
+
 #gli
 option(_NBL_COMPILE_WITH_GLI_ "Build with GLI library" ON)
 if(_NBL_COMPILE_WITH_GLI_)

include/nbl/builtin/hlsl/fft/common.hlsl
Lines changed: 2 additions & 2 deletions

@@ -21,9 +21,9 @@ complex_t<Scalar> twiddle(uint32_t k, uint32_t halfN)
 	const Scalar kthRootAngleRadians = numbers::pi<Scalar> * Scalar(k) / Scalar(halfN);
 	retVal.real( cos(kthRootAngleRadians) );
 	if (! inverse)
-		retVal.imag( sin(kthRootAngleRadians) );
-	else
 		retVal.imag( sin(-kthRootAngleRadians) );
+	else
+		retVal.imag( sin(kthRootAngleRadians) );
 	return retVal;
 }
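
The swap makes the forward transform use the negative-angle root e^(-i*pi*k/halfN) and the inverse its conjugate, which is the standard DFT sign convention. A minimal standalone C++ check of the fixed behaviour, using plain std::complex rather than Nabla's complex_t:

#include <cmath>
#include <complex>
#include <cstdio>

// Mirrors the fixed twiddle(): negative sine for the forward transform,
// positive sine for the inverse (i.e. the conjugate root).
std::complex<double> twiddle(bool inverse, unsigned k, unsigned halfN)
{
    const double pi = std::acos(-1.0);
    const double kthRootAngleRadians = pi * double(k) / double(halfN);
    const double s = std::sin(kthRootAngleRadians);
    return { std::cos(kthRootAngleRadians), inverse ? s : -s };
}

int main()
{
    // k = 1, halfN = 2 is the 4th root of unity W_4^1; the forward twiddle must be -i
    const std::complex<double> w = twiddle(false, 1, 2);
    std::printf("%.3f %+.3fi\n", w.real(), w.imag()); // 0.000 -1.000i
}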

include/nbl/builtin/hlsl/functional.hlsl
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ namespace nbl
 {
 namespace hlsl
 {
-#ifdef __HLSL_VERSION // CPP
+#ifdef __HLSL_VERSION // HLSL
 template<uint32_t StorageClass, typename T>
 using __spv_ptr_t = spirv::pointer_t<StorageClass,T>;

include/nbl/builtin/hlsl/workgroup/fft.hlsl
Lines changed: 108 additions & 68 deletions

@@ -23,85 +23,125 @@ namespace fft
 {
 
 // ---------------------------------- Utils -----------------------------------------------
-template<typename SharedMemoryAdaptor, typename Scalar>
-struct exchangeValues;
-
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float16_t>
-{
-	static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
-	{
-		const bool topHalf = bool(threadID & stride);
-		// Pack two halves into a single uint32_t
-		uint32_t toExchange = bit_cast<uint32_t, float16_t2 >(topHalf ? float16_t2 (lo.real(), lo.imag()) : float16_t2 (hi.real(), hi.imag()));
-		shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
-		float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
-		if (topHalf)
-		{
-			lo.real(exchanged.x);
-			lo.imag(exchanged.y);
-		}
-		else
-		{
-			hi.real(exchanged.x);
-			lo.imag(exchanged.y);
-		}
-	}
-};
-
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float32_t>
-{
-	static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
-	{
-		const bool topHalf = bool(threadID & stride);
-		// pack into `float32_t2` because ternary operator doesn't support structs
-		float32_t2 exchanged = topHalf ? float32_t2(lo.real(), lo.imag()) : float32_t2(hi.real(), hi.imag());
-		shuffleXor<SharedMemoryAdaptor, float32_t2>::__call(exchanged, stride, sharedmemAdaptor);
-		if (topHalf)
-		{
-			lo.real(exchanged.x);
-			lo.imag(exchanged.y);
-		}
-		else
-		{
-			hi.real(exchanged.x);
-			hi.imag(exchanged.y);
-		}
-	}
-};
-
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float64_t>
-{
-	static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
-	{
-		const bool topHalf = bool(threadID & stride);
-		// pack into `float64_t2` because ternary operator doesn't support structs
-		float64_t2 exchanged = topHalf ? float64_t2(lo.real(), lo.imag()) : float64_t2(hi.real(), hi.imag());
-		shuffleXor<SharedMemoryAdaptor, float64_t2 >::__call(exchanged, stride, sharedmemAdaptor);
-		if (topHalf)
-		{
-			lo.real(exchanged.x);
-			lo.imag(exchanged.y);
-		}
-		else
-		{
-			hi.real(exchanged.x);
-			hi.imag(exchanged.y);
-		}
-	}
-};
+// No need to expose these
+namespace impl
+{
+template<typename SharedMemoryAdaptor, typename Scalar>
+struct exchangeValues
+{
+	static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	{
+		const bool topHalf = bool(threadID & stride);
+		// Pack into float vector because ternary operator does not support structs
+		vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
+		shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
+		if (topHalf)
+		{
+			lo.real(exchanged.x);
+			lo.imag(exchanged.y);
+		}
+		else
+		{
+			hi.real(exchanged.x);
+			hi.imag(exchanged.y);
+		}
+	}
+};
+
+template<uint16_t N, uint16_t H>
+enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
+{
+	// Highest H bits are numbered N-1 through N - H
+	// N - H is then the middle bit
+	// Lowest bits numbered from 0 through N - H - 1
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask);
+
+	uint32_t low = i & lowMask;
+	uint32_t mid = i & midMask;
+	uint32_t high = i & highMask;
+
+	high >>= 1;
+	mid <<= H - 1;
+
+	return mid | high | low;
+}
+
+template<uint16_t N, uint16_t H>
+enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
+{
+	// Highest H bits are numbered N-1 through N - H
+	// N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
+	// Lowest bits numbered from 0 through N - H - 1
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask);
+
+	uint32_t low = i & lowMask;
+	uint32_t mid = i & midMask;
+	uint32_t high = i & highMask;
+
+	mid <<= 1;
+	high >>= H - 1;
+
+	return mid | high | low;
+}
+} //namespace impl
 
 // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
-template <typename scalar_t, uint32_t WorkgroupSize>
+template <typename scalar_t, uint16_t WorkgroupSize>
 NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
 
+// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
+template<typename Scalar>
+void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
+{
+	complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
+	hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
+	lo = x;
+}
+
+template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
+struct FFTIndexingUtils
+{
+	// This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
+	// This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
+	static uint32_t getDFTIndex(uint32_t outputIdx)
+	{
+		return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFTSizeLog2));
+	}
+
+	// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
+	// It is essentially the inverse of `getDFTIndex`
+	static uint32_t getNablaIndex(uint32_t freqIdx)
+	{
+		return glsl::bitfieldReverse<uint32_t>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx)) >> (32 - FFTSizeLog2);
+	}
+
+	// Mirrors an index about the Nyquist frequency in the DFT order
+	static uint32_t getDFTMirrorIndex(uint32_t idx)
+	{
+		return (FFTSize - idx) & (FFTSize - 1);
+	}
+
+	// Given an index `idx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
+	static uint32_t getNablaMirrorIndex(uint32_t idx)
+	{
+		return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
+	}
+
+	NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
+	NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
+	NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
+};
+
 } //namespace fft
 
 // ----------------------------------- End Utils -----------------------------------------------
 
-template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
+template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
 struct FFT;
 
 // For the FFT methods below, we assume:
@@ -121,13 +161,13 @@ struct FFT;
 // * void workgroupExecutionAndMemoryBarrier();
 
 // 2 items per invocation forward specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
 {
 	template<typename SharedMemoryAdaptor>
 	static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
 	{
-		fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+		fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
 
 		// Get twiddle with k = threadID mod stride, halfN = stride
 		hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
@@ -167,7 +207,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
 		}
 
 		// special last workgroup-shuffle
-		fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+		fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
 
 		// Remember to update the accessor's state
 		sharedmemAccessor = sharedmemAdaptor.accessor;
@@ -185,7 +225,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
 
 
 // 2 items per invocation inverse specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 {
 	template<typename SharedMemoryAdaptor>
@@ -194,7 +234,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 		// Get twiddle with k = threadID mod stride, halfN = stride
 		hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
 
-		fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+		fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
 	}
 
 
@@ -223,7 +263,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 		sharedmemAdaptor.accessor = sharedmemAccessor;
 
 		// special first workgroup-shuffle
-		fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+		fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
 
 		// The bigger steps
 		[unroll]
@@ -251,7 +291,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 };
 
 // Forward FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 {
 	template<typename Accessor, typename SharedMemoryAccessor>
@@ -294,7 +334,7 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 };
 
 // Inverse FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
 {
 	template<typename Accessor, typename SharedMemoryAccessor>
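
For context on the new `unpack` utility: packing two real sequences x and y as z = x + iy lets one complex FFT produce both spectra, recovered via X[k] = (Z[k] + conj(Z[N-k]))/2 and Y[k] = (Z[k] - conj(Z[N-k]))/2i. Dividing by i is multiplying by -i, which is presumably the role of `rotateRight` in the header (an assumption; its definition is not shown in this diff). A standalone C++ sketch of the identity with plain std::complex:

#include <cassert>
#include <complex>

using cd = std::complex<double>;

// Standalone analogue of fft::unpack: lo = Z[k], hi = Z[N-k] on input;
// lo = X[k], hi = Y[k] on output. Division by i is done as multiplication
// by -i, the assumed meaning of rotateRight in the header.
void unpack(cd& lo, cd& hi)
{
    const cd x = (lo + std::conj(hi)) * 0.5;
    hi = (lo - std::conj(hi)) * cd(0.0, -0.5); // (lo - conj(hi)) / (2i)
    lo = x;
}

int main()
{
    // x = {1, 2}, y = {3, 4}, z = x + iy; the 2-point DFT gives
    // Z[0] = 3+7i and Z[1] = -1-1i, while X = {3, -1} and Y = {7, -1}.
    cd lo(3.0, 7.0), hi(3.0, 7.0); // k = 0 mirrors onto itself
    unpack(lo, hi);
    assert(lo == cd(3.0, 0.0) && hi == cd(7.0, 0.0));
}

`FFTIndexingUtils` then composes `glsl::bitfieldReverse` with the circular shifts to translate between the Nabla FFT's output order and natural DFT order. A constexpr C++ analogue of `impl::circularBitShiftRightHigher` with an illustrative instantiation (the `<5, 3>` parameters are chosen for the example, not taken from the commit):

#include <cstdint>

// Standalone analogue of impl::circularBitShiftRightHigher<N, H>: rotate the
// top H bits of an N-bit index right by one, keep the low N-H bits; assumes i < 2^N.
template<uint16_t N, uint16_t H>
constexpr uint32_t circularBitShiftRightHigher(uint32_t i)
{
    constexpr uint32_t lowMask  = (1u << (N - H)) - 1;  // bits 0 .. N-H-1
    constexpr uint32_t midMask  = 1u << (N - H);        // lowest of the top H bits
    constexpr uint32_t highMask = ~(lowMask | midMask); // the other top bits
    const uint32_t low  = i & lowMask;
    const uint32_t mid  = i & midMask;
    const uint32_t high = i & highMask;
    // The lowest of the top H bits wraps around to bit N-1; the rest shift down by one
    return (mid << (H - 1)) | (high >> 1) | low;
}

// N = 5, H = 3: the top bits 101 of 0b10110 rotate right to 110, low bits 10 stay
static_assert(circularBitShiftRightHigher<5, 3>(0b10110u) == 0b11010u, "rotation check");

int main() {}

`circularBitShiftLeftHigher` is the inverse rotation, which is why `getNablaIndex` undoes `getDFTIndex`.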

include/nbl/builtin/hlsl/workgroup/shuffle.hlsl
Lines changed: 70 additions & 8 deletions

@@ -2,6 +2,7 @@
 #define _NBL_BUILTIN_HLSL_WORKGROUP_SHUFFLE_INCLUDED_
 
 #include "nbl/builtin/hlsl/memory_accessor.hlsl"
+#include "nbl/builtin/hlsl/functional.hlsl"
 
 // TODO: Add other shuffles
 
@@ -14,26 +15,87 @@ namespace hlsl
 namespace workgroup
 {
 
+// ------------------------------------- Skeletons for implementing other Shuffles --------------------------------
+
 template<typename SharedMemoryAdaptor, typename T>
-struct shuffleXor
+struct Shuffle
+{
+	static void __call(NBL_REF_ARG(T) value, uint32_t storeIdx, uint32_t loadIdx, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	{
+		// TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
+		sharedmemAdaptor.template set<T>(storeIdx, value);
+
+		// Wait until all writes are done before reading
+		sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+
+		sharedmemAdaptor.template get<T>(loadIdx, value);
+	}
+
+	// By default store to threadID in the workgroup
+	static void __call(NBL_REF_ARG(T) value, uint32_t loadIdx, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	{
+		__call(value, uint32_t(SubgroupContiguousIndex()), loadIdx, sharedmemAdaptor);
+	}
+};
+
+template<class UnOp, typename SharedMemoryAdaptor, typename T>
+struct ShuffleUnOp
+{
+	static void __call(NBL_REF_ARG(T) value, uint32_t a, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	{
+		UnOp unop;
+		// TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
+		sharedmemAdaptor.template set<T>(a, value);
+
+		// Wait until all writes are done before reading
+		sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+
+		sharedmemAdaptor.template get<T>(unop(a), value);
+	}
+
+	// By default store to threadID's index and load from unop(threadID)
+	static void __call(NBL_REF_ARG(T) value, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	{
+		__call(value, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
+	}
+};
+
+template<class BinOp, typename SharedMemoryAdaptor, typename T>
+struct ShuffleBinOp
 {
-	static void __call(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	static void __call(NBL_REF_ARG(T) value, uint32_t a, uint32_t b, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
 	{
+		BinOp binop;
 		// TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
-		sharedmemAdaptor.template set<T>(threadID, value);
+		sharedmemAdaptor.template set<T>(a, value);
 
 		// Wait until all writes are done before reading
 		sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
 
-		sharedmemAdaptor.template get<T>(threadID ^ mask, value);
+		sharedmemAdaptor.template get<T>(binop(a, b), value);
 	}
 
-	static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+	// By default first argument of binary op is the thread's ID in the workgroup
+	static void __call(NBL_REF_ARG(T) value, uint32_t b, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
 	{
-		__call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
+		__call(value, uint32_t(SubgroupContiguousIndex()), b, sharedmemAdaptor);
 	}
 };
 
+// ------------------------------------------ ShuffleXor ---------------------------------------------------------------
+
+template<typename SharedMemoryAdaptor, typename T>
+void shuffleXor(NBL_REF_ARG(T) value, uint32_t threadID, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+{
+	return ShuffleBinOp<bit_xor<uint32_t>, SharedMemoryAdaptor, T>::__call(value, threadID, mask, sharedmemAdaptor);
+}
+
+template<typename SharedMemoryAdaptor, typename T>
+void shuffleXor(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+{
+	return ShuffleBinOp<bit_xor<uint32_t>, SharedMemoryAdaptor, T>::__call(value, mask, sharedmemAdaptor);
+}
+
 }
 }
 }
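
The three skeletons share one pattern: each thread writes its value to a shared-memory slot, the workgroup barriers, then each thread reads back from a slot computed by an index op; the new free-function `shuffleXor` is just `ShuffleBinOp` instantiated with `bit_xor`. A toy single-threaded C++ model of that data movement (illustrative only; `shuffleBinOp` and its sequential loops are hypothetical stand-ins for what a workgroup does in parallel):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Toy single-threaded model of ShuffleBinOp: every "thread" t stores its value
// at slot t, then (after the barrier) loads from slot binop(t, b). A real
// workgroup runs both phases in parallel through shared memory.
template<class BinOp>
void shuffleBinOp(std::vector<uint32_t>& values, uint32_t b)
{
    BinOp binop;
    const uint32_t n = uint32_t(values.size());
    std::vector<uint32_t> shared(n);
    for (uint32_t t = 0; t < n; ++t) // set<T>(t, value)
        shared[t] = values[t];
    // workgroupExecutionAndMemoryBarrier() would go here
    for (uint32_t t = 0; t < n; ++t) // get<T>(binop(t, b), value)
        values[t] = shared[binop(t, b)];
}

int main()
{
    std::vector<uint32_t> v = {10, 11, 12, 13};
    shuffleBinOp<std::bit_xor<uint32_t>>(v, 1); // shuffleXor with mask = 1
    for (uint32_t x : v)
        std::printf("%u ", x); // 11 10 13 12
}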
