Skip to content

Commit fdb7904

Browse files
committed
Move some HLSL stuff to CPP-shared
1 parent 6401e53 commit fdb7904

File tree

5 files changed

+135
-110
lines changed

5 files changed

+135
-110
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef _NBL_BUILTIN_HLSL_BITREVERSE_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_BITREVERSE_INCLUDED_
3+
4+
5+
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
6+
7+
namespace nbl
8+
{
9+
namespace hlsl
10+
{
11+
12+
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_unsigned_v<T>&& Bits <= sizeof(T) * 8)
13+
/**
14+
* @brief Takes the binary representation of `value` as a string of `Bits` bits and returns a value of the same type resulting from reversing the string.
15+
*
* Implemented as a full-width `bitReverse<T>` followed by a right shift by `sizeof(T) * 8 - Bits`.
* Assuming `bitReverse<T>` reverses all `sizeof(T) * 8` bits, any bits of `value` at position `Bits` or above are shifted out of the result.
*
16+
* @tparam T Type of the value to operate on. Constrained to satisfy `is_unsigned_v<T>`.
17+
* @tparam Bits The length of the string of bits used to represent `value`. Constrained at compile time to `Bits <= sizeof(T) * 8`.
18+
*
19+
* @param [in] value The value to bitreverse.
*
* @return `bitReverse<T>(value)` shifted right by `sizeof(T) * 8 - Bits`.
*
* NOTE(review): `Bits == 0` is accepted by the constraint but makes the shift amount equal to the full bit width of `T` - confirm that case is never instantiated.
20+
*/
21+
T bitReverseAs(T value)
22+
{
23+
// The shift amount is built as T's scalar type and passed through `promote` so that it has type T
// (presumably this broadcasts it per component when T is a vector - TODO confirm).
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - Bits));
24+
}
25+
26+
template<typename T NBL_FUNC_REQUIRES(is_unsigned_v<T>)
27+
/**
28+
* @brief Takes the binary representation of `value` and returns a value of the same type resulting from reversing the string of bits as if it was `bits` long.
29+
* Keep in mind `bits` cannot exceed `8 * sizeof(T)`.
* Unlike the overload taking `Bits` as a template parameter, this bound is not checked anywhere in this function: the caller is responsible for it.
*
* Implemented as a full-width `bitReverse<T>` followed by a right shift by `sizeof(T) * 8 - bits`.
30+
*
31+
* @tparam T type of the value to operate on. Constrained to satisfy `is_unsigned_v<T>`.
32+
*
33+
* @param [in] value The value to bitreverse.
34+
* @param [in] bits The length of the string of bits used to represent `value`.
*
* @return `bitReverse<T>(value)` shifted right by `sizeof(T) * 8 - bits`.
*
* NOTE(review): `bits == 0` makes the shift amount equal to the full bit width of `T` - confirm callers never pass it.
35+
*/
36+
T bitReverseAs(T value, uint16_t bits)
37+
{
38+
// The shift amount is built as T's scalar type and passed through `promote` so that it has type T
// (presumably this broadcasts it per component when T is a vector - TODO confirm).
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - bits));
39+
}
40+
41+
42+
}
43+
}
44+
45+
46+
47+
#endif

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -124,35 +124,6 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
124124
lo = x;
125125
}
126126

127-
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_unsigned_v<T>&& Bits <= sizeof(T) * 8)
128-
/**
129-
* @brief Takes the binary representation of `value` as a string of `Bits` bits and returns a value of the same type resulting from reversing the string
130-
*
131-
* @tparam T Type of the value to operate on.
132-
* @tparam Bits The length of the string of bits used to represent `value`.
133-
*
134-
* @param [in] value The value to bitreverse.
135-
*/
136-
T bitReverseAs(T value)
137-
{
138-
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - Bits));
139-
}
140-
141-
template<typename T NBL_FUNC_REQUIRES(is_unsigned_v<T>)
142-
/**
143-
* @brief Takes the binary representation of `value` and returns a value of the same type resulting from reversing the string of bits as if it was `bits` long.
144-
* Keep in mind `bits` cannot exceed `8 * sizeof(T)`.
145-
*
146-
* @tparam T type of the value to operate on.
147-
*
148-
* @param [in] value The value to bitreverse.
149-
* @param [in] bits The length of the string of bits used to represent `value`.
150-
*/
151-
T bitReverseAs(T value, uint16_t bits)
152-
{
153-
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - bits));
154-
}
155-
156127
}
157128
}
158129
}

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 85 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
22
#include <nbl/builtin/hlsl/concepts.hlsl>
33
#include <nbl/builtin/hlsl/fft/common.hlsl>
4+
#include <nbl/builtin/hlsl/bitreverse.hlsl>
45

56
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
67
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
@@ -77,6 +78,83 @@ inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint
7778
}
7879
}
7980

81+
namespace impl
82+
{
83+
// Treats `i` as an N-bit index and rotates its highest H bits right by one position, leaving the lowest N - H bits untouched.
// The lowest of those H bits (bit N - H) wraps around to bit N - 1, and the remaining high bits each move down by one.
//
// N: number of bits in the index; the return type only exists when N < 32.
// H: number of high bits taking part in the rotation; the return type only exists when H <= N.
//
// Bits of `i` at position N or above are not masked off: they are part of `highMask` and are shifted right along with the rest.
// Presumably callers only pass i < 2^N - TODO confirm.
// NOTE(review): the constraint admits H == 0, for which `H - 1` is not a valid shift amount - confirm H >= 1 at every call site.
template<uint16_t N, uint16_t H>
84+
enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
85+
{
86+
// Highest H bits are numbered N-1 through N - H
87+
// N - H is then the middle bit, i.e. the lowest of the H high bits and the one that wraps around
88+
// Lowest bits numbered from 0 through N - H - 1
89+
NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
90+
NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
91+
// Everything that is neither low nor middle: bits N - H + 1 and above
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask);
92+
93+
uint32_t low = i & lowMask;
94+
uint32_t mid = i & midMask;
95+
uint32_t high = i & highMask;
96+
97+
// Move the upper bits down by one...
high >>= 1;
98+
// ...and carry the middle bit from position N - H up to position N - 1
mid <<= H - 1;
99+
100+
return mid | high | low;
101+
}
102+
103+
// Treats `i` as an N-bit index and rotates its highest H bits left by one position, leaving the lowest N - H bits untouched.
// The top bit (bit N - 1) wraps around to bit N - H, and the remaining high bits each move up by one.
//
// N: number of bits in the index; the return type only exists when N < 32.
// H: number of high bits taking part in the rotation; the return type only exists when H <= N.
//
// Bits of `i` at position N or above are not masked off: they are part of `midMask` and are shifted left along with the rest.
// Presumably callers only pass i < 2^N - TODO confirm.
// NOTE(review): the constraint admits H == 0, for which `H - 1` is not a valid shift amount - confirm H >= 1 at every call site.
template<uint16_t N, uint16_t H>
104+
enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
105+
{
106+
// Highest H bits are numbered N-1 through N - H
107+
// N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
108+
// Lowest bits numbered from 0 through N - H - 1
109+
NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
110+
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
111+
// Everything that is neither low nor the highest bit
NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask);
112+
113+
uint32_t low = i & lowMask;
114+
uint32_t mid = i & midMask;
115+
uint32_t high = i & highMask;
116+
117+
// Move the middle bits up by one...
mid <<= 1;
118+
// ...and carry the highest bit from position N - 1 down to position N - H
high >>= H - 1;
119+
120+
return mid | high | low;
121+
}
122+
} //namespace impl
123+
124+
// Static helpers that convert between positions in the output array of a Nabla FFT and frequency indices of the DFT,
// for an FFT of size 2^(ElementsPerInvocationLog2 + WorkgroupSizeLog2).
// All members are static; the struct carries no state.
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
125+
struct FFTIndexingUtils
126+
{
127+
// This function maps the index `outputIdx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[outputIdx]`
128+
// This is because Cooley-Tukey + subgroup operations end up producing the outputs in a permuted order
// Computed as a bit reversal over `FFTSizeLog2` bits followed by a one-position right rotation of the highest
// `FFTSizeLog2 - ElementsPerInvocationLog2 + 1` (that is, `WorkgroupSizeLog2 + 1`) bits.
129+
static uint32_t getDFTIndex(uint32_t outputIdx)
130+
{
131+
return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(hlsl::bitReverseAs<uint32_t, FFTSizeLog2>(outputIdx));
132+
}
133+
134+
// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
135+
// It is essentially the inverse of `getDFTIndex`: the same two steps undone in the opposite order (left rotation first, then bit reversal)
136+
static uint32_t getNablaIndex(uint32_t freqIdx)
137+
{
138+
return hlsl::bitReverseAs<uint32_t, FFTSizeLog2>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx));
139+
}
140+
141+
// Mirrors an index about the Nyquist frequency in the DFT order
// Evaluates `FFTSize - freqIdx` and masks with `FFTSize - 1`; since `FFTSize` is a power of two the mask wraps `freqIdx == 0` back to 0
142+
static uint32_t getDFTMirrorIndex(uint32_t freqIdx)
143+
{
144+
return (FFTSize - freqIdx) & (FFTSize - 1);
145+
}
146+
147+
// Given an index `outputIdx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
// Done by converting to DFT order, mirroring there, and converting back to Nabla order
148+
static uint32_t getNablaMirrorIndex(uint32_t outputIdx)
149+
{
150+
return getNablaIndex(getDFTMirrorIndex(getDFTIndex(outputIdx)));
151+
}
152+
153+
// log2 of the FFT size: the sum of both template parameters
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + WorkgroupSizeLog2;
154+
// Number of elements in the FFT, always a power of two
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(1) << FFTSizeLog2;
155+
// 2^WorkgroupSizeLog2
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = uint32_t(1) << WorkgroupSizeLog2;
156+
};
157+
80158
}
81159
}
82160
}
@@ -135,76 +213,12 @@ namespace impl
135213
}
136214
}
137215
};
138-
139-
template<uint16_t N, uint16_t H>
140-
enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
141-
{
142-
// Highest H bits are numbered N-1 through N - H
143-
// N - H is then the middle bit
144-
// Lowest bits numbered from 0 through N - H - 1
145-
NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
146-
NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
147-
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask);
148-
149-
uint32_t low = i & lowMask;
150-
uint32_t mid = i & midMask;
151-
uint32_t high = i & highMask;
152-
153-
high >>= 1;
154-
mid <<= H - 1;
155-
156-
return mid | high | low;
157-
}
158-
159-
template<uint16_t N, uint16_t H>
160-
enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
161-
{
162-
// Highest H bits are numbered N-1 through N - H
163-
// N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
164-
// Lowest bits numbered from 0 through N - H - 1
165-
NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
166-
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
167-
NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask);
168-
169-
uint32_t low = i & lowMask;
170-
uint32_t mid = i & midMask;
171-
uint32_t high = i & highMask;
172-
173-
mid <<= 1;
174-
high >>= H - 1;
175-
176-
return mid | high | low;
177-
}
178216
} //namespace impl
179217

180218
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
181-
struct FFTIndexingUtils
219+
struct FFTMirrorTradeUtils
182220
{
183-
// This function maps the index `outputIdx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[outputIdx]`
184-
// This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
185-
static uint32_t getDFTIndex(uint32_t outputIdx)
186-
{
187-
return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(hlsl::fft::bitReverseAs<uint32_t, FFTSizeLog2>(outputIdx));
188-
}
189-
190-
// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
191-
// It is essentially the inverse of `getDFTIndex`
192-
static uint32_t getNablaIndex(uint32_t freqIdx)
193-
{
194-
return hlsl::fft::bitReverseAs<uint32_t, FFTSizeLog2>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx));
195-
}
196-
197-
// Mirrors an index about the Nyquist frequency in the DFT order
198-
static uint32_t getDFTMirrorIndex(uint32_t freqIdx)
199-
{
200-
return (FFTSize - freqIdx) & (FFTSize - 1);
201-
}
202-
203-
// Given an index `outputIdx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
204-
static uint32_t getNablaMirrorIndex(uint32_t outputIdx)
205-
{
206-
return getNablaIndex(getDFTMirrorIndex(getDFTIndex(outputIdx)));
207-
}
221+
using indexing_utils_t = FFTIndexingUtils<ElementsPerInvocationLog2, WorkgroupSizeLog2>;
208222

209223
// When unpacking an FFT of two packed signals, given a `globalElementIndex` you need its "mirror index" to unpack the value at NablaFFT[globalElementIndex].
210224
// The function above has you covered in that sense, but what also happens is that not only does the thread holding `NablaFFT[globalElementIndex]` need its mirror value
@@ -216,10 +230,10 @@ struct FFTIndexingUtils
216230
uint32_t otherThreadID;
217231
uint32_t mirrorLocalIndex;
218232
};
219-
233+
220234
static NablaMirrorLocalInfo getNablaMirrorLocalInfo(uint32_t globalElementIndex)
221235
{
222-
const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex(globalElementIndex);
236+
const uint32_t otherElementIndex = indexing_utils_t::getNablaMirrorIndex(globalElementIndex);
223237
const uint32_t mirrorLocalIndex = otherElementIndex / WorkgroupSize;
224238
const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1);
225239
const NablaMirrorLocalInfo info = { otherThreadID, mirrorLocalIndex };
@@ -235,23 +249,13 @@ struct FFTIndexingUtils
235249

236250
static NablaMirrorGlobalInfo getNablaMirrorGlobalInfo(uint32_t globalElementIndex)
237251
{
238-
const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex(globalElementIndex);
252+
const uint32_t otherElementIndex = indexing_utils_t::getNablaMirrorIndex(globalElementIndex);
239253
const uint32_t mirrorGlobalIndex = glsl::bitfieldInsert<uint32_t>(otherElementIndex, workgroup::SubgroupContiguousIndex(), 0, uint32_t(WorkgroupSizeLog2));
240254
const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1);
241255
const NablaMirrorGlobalInfo info = { otherThreadID, mirrorGlobalIndex };
242256
return info;
243257
}
244258

245-
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + WorkgroupSizeLog2;
246-
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(1) << FFTSizeLog2;
247-
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = uint32_t(1) << WorkgroupSizeLog2;
248-
};
249-
250-
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
251-
struct FFTMirrorTradeUtils
252-
{
253-
using indexing_utils_t = FFTIndexingUtils<ElementsPerInvocationLog2, WorkgroupSizeLog2>;
254-
using mirror_info_t = typename indexing_utils_t::NablaMirrorGlobalInfo;
255259
// If trading elements when, for example, unpacking real FFTs, you might do so from within your accessor or from outside.
256260
// If doing so from within your accessor, particularly if using a preloaded accessor, you might want to do this yourself by
257261
// using FFTIndexingUtils::getNablaMirrorTradeInfo and trading the elements yourself (an example of how to set this up is given in
@@ -261,7 +265,7 @@ struct FFTMirrorTradeUtils
261265
template<typename scalar_t, typename fft_array_accessor_t, typename shared_memory_adaptor_t>
262266
static complex_t<scalar_t> getNablaMirror(uint32_t globalElementIndex, fft_array_accessor_t arrayAccessor, shared_memory_adaptor_t sharedmemAdaptor)
263267
{
264-
const mirror_info_t mirrorInfo = indexing_utils_t::getNablaMirrorGlobalInfo(globalElementIndex);
268+
const NablaMirrorGlobalInfo mirrorInfo = getNablaMirrorGlobalInfo(globalElementIndex);
265269
complex_t<scalar_t> toTrade = arrayAccessor.get(mirrorInfo.mirrorGlobalIndex);
266270
vector<scalar_t, 2> toTradeVector = { toTrade.real(), toTrade.imag() };
267271
workgroup::Shuffle<shared_memory_adaptor_t, vector<scalar_t, 2> >::__call(toTradeVector, mirrorInfo.otherThreadID, sharedmemAdaptor);
@@ -271,6 +275,7 @@ struct FFTMirrorTradeUtils
271275
}
272276

273277
NBL_CONSTEXPR_STATIC_INLINE indexing_utils_t IndexingUtils;
278+
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = indexing_utils_t::WorkgroupSize;
274279
};
275280

276281

src/nbl/builtin/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,5 +353,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/mip_mapped
353353
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/storable_image.hlsl")
354354
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/fft.hlsl")
355355

356+
# temporary (delete once replaced)
357+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/bitreverse.hlsl")
356358

357359
ADD_CUSTOM_BUILTIN_RESOURCES(nblBuiltinResourceData NBL_RESOURCES_TO_EMBED "${NBL_ROOT_PATH}/include" "nbl/builtin" "nbl::builtin" "${NBL_ROOT_PATH_BINARY}/include" "${NBL_ROOT_PATH_BINARY}/src" "STATIC" "INTERNAL")

0 commit comments

Comments
 (0)