
Commit 70b1caa

Merge pull request #698 from Devsh-Graphics-Programming/nahim_complex_fft_workgroup
FFT Workgroup Ops
2 parents: 2dc2bfb + 0a7e924

16 files changed: 651 additions, 86 deletions

include/nbl/builtin/hlsl/bda/__ref.hlsl
11 additions, 1 deletion

@@ -12,10 +12,20 @@ namespace nbl
 {
 namespace hlsl
 {
+
+// TODO: make a common `nbl/builtin/hlsl/__ref.hlsl`
+// TODO: also refactor `bda::__base_ref` into just `__ref` and make it a typedef
+template<uint32_t StorageClass, typename T>
+using __spv_ptr_t = spirv::pointer_t<StorageClass,T>;
+
+template<uint32_t StorageClass, typename T>
+[[vk::ext_instruction(spv::OpCopyObject)]]
+__spv_ptr_t<StorageClass,T> addrof([[vk::ext_reference]] T v);
+
 namespace bda
 {
 template<typename T>
-using __spv_ptr_t = spirv::pointer_t<spv::StorageClassPhysicalStorageBuffer, T>;
+using __spv_ptr_t = spirv::pointer_t<spv::StorageClassPhysicalStorageBuffer,T>;
 
 template<typename T>
 struct __ptr;
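
For orientation, a minimal usage sketch of the new generic `addrof` (not part of the diff): it yields a typed SPIR-V pointer to a variable, here a function-local one. The `spv::StorageClassFunction` constant and the wrapper function are illustrative assumptions.

    // hypothetical caller, assuming the header above is included
    void addrofExample()
    {
        float32_t x = 1.0;
        // OpCopyObject forwards the implicit pointer DXC creates for the
        // [[vk::ext_reference]] parameter, yielding "&x" as a value
        nbl::hlsl::__spv_ptr_t<spv::StorageClassFunction,float32_t> px =
            nbl::hlsl::addrof<spv::StorageClassFunction,float32_t>(x);
    }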

include/nbl/builtin/hlsl/bit.hlsl
2 additions, 3 deletions

@@ -32,9 +32,8 @@ namespace hlsl
 {
 
 template<class T, class U>
-T bit_cast(U val)
+enable_if_t<sizeof(T)==sizeof(U),T> bit_cast(U val)
 {
-    static_assert(sizeof(T)==sizeof(U));
     return spirv::bitcast<T,U>(val);
 }
 
@@ -92,7 +91,7 @@ uint16_t clz(uint64_t N)
 template<>
 uint16_t clz<1>(uint64_t N) { return uint16_t(1u-N&1); }
 
-}
+} //namespace impl
 
 template<typename T>
 uint16_t countl_zero(T n)
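
The `bit_cast` change moves the size check from a `static_assert` in the body into the return type, so a mismatched cast is now a substitution failure rather than a hard error inside an instantiation. A hedged sketch of the observable difference:

    uint32_t bits = nbl::hlsl::bit_cast<uint32_t,float32_t>(1.0f); // fine, sizes match (4 == 4)
    // nbl::hlsl::bit_cast<uint64_t,float32_t>(1.0f); // no longer a static_assert, simply no viable overload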

include/nbl/builtin/hlsl/complex.hlsl
8 additions, 0 deletions

@@ -191,6 +191,10 @@ const static complex_t< SCALAR > multiplies< complex_t< SCALAR > >::identity = {
 template<> \
 const static complex_t< SCALAR > divides< complex_t< SCALAR > >::identity = { promote< SCALAR , uint32_t>(1), promote< SCALAR , uint32_t>(0)};
 
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t)
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t2)
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t3)
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t4)
 COMPLEX_ARITHMETIC_IDENTITIES(float32_t)
 COMPLEX_ARITHMETIC_IDENTITIES(float32_t2)
 COMPLEX_ARITHMETIC_IDENTITIES(float32_t3)
@@ -287,6 +291,10 @@ COMPLEX_COMPOUND_ASSIGN_IDENTITY(minus, SCALAR) \
 COMPLEX_COMPOUND_ASSIGN_IDENTITY(multiplies, SCALAR) \
 COMPLEX_COMPOUND_ASSIGN_IDENTITY(divides, SCALAR)
 
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t)
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t2)
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t3)
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t4)
 COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float32_t)
 COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float32_t2)
 COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float32_t3)
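
For reference, each `COMPLEX_ARITHMETIC_IDENTITIES(SCALAR)` invocation stamps out the identity constants of the arithmetic functors for that scalar; judging from the `divides` case visible in the context above, the new `float16_t` line expands to roughly:

    template<>
    const static complex_t<float16_t> divides<complex_t<float16_t> >::identity =
        { promote<float16_t,uint32_t>(1), promote<float16_t,uint32_t>(0) }; // 1 + 0i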

include/nbl/builtin/hlsl/fft/common.hlsl
6 additions, 6 deletions

@@ -14,11 +14,11 @@ namespace fft
 
 // Computes the kth element in the group of N roots of unity
 // Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT)
-template<typename Scalar, bool inverse>
-complex_t<Scalar> twiddle(uint32_t k, uint32_t N)
+template<bool inverse, typename Scalar>
+complex_t<Scalar> twiddle(uint32_t k, uint32_t halfN)
 {
     complex_t<Scalar> retVal;
-    const Scalar kthRootAngleRadians = 2.f * numbers::pi<Scalar> * Scalar(k) / Scalar(N);
+    const Scalar kthRootAngleRadians = numbers::pi<Scalar> * Scalar(k) / Scalar(halfN);
     retVal.real( cos(kthRootAngleRadians) );
     if (! inverse)
         retVal.imag( sin(kthRootAngleRadians) );
@@ -27,7 +27,7 @@ complex_t<Scalar> twiddle(uint32_t k, uint32_t N)
     return retVal;
 }
 
-template<typename Scalar, bool inverse>
+template<bool inverse, typename Scalar>
 struct DIX
 {
     static void radix2(NBL_CONST_REF_ARG(complex_t<Scalar>) twiddle, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
@@ -49,10 +49,10 @@ struct DIX
 };
 
 template<typename Scalar>
-using DIT = DIX<Scalar, true>;
+using DIT = DIX<true, Scalar>;
 
 template<typename Scalar>
-using DIF = DIX<Scalar, false>;
+using DIF = DIX<false, Scalar>;
 }
 }
 }
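
The `twiddle` signature change is value-preserving: since `k` only ever ranges over `[0, N/2)`, callers already hold `N/2`, and `pi * k / halfN` equals the old `2 * pi * k / N` exactly. For example, with N = 8 and k = 1 the old code computed 2π·1/8 = π/4 and the new code computes π·1/4 = π/4. A hedged call-site sketch with the reordered template parameters:

    // forward (DIF) twiddle for an 8-point FFT, so halfN = N/2 = 4
    complex_t<float32_t> w = nbl::hlsl::fft::twiddle<false,float32_t>(1, 4);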

include/nbl/builtin/hlsl/glsl_compat/core.hlsl
1 addition, 1 deletion

@@ -199,7 +199,7 @@ struct bitfieldExtract<T, false, true>
     }
 };
 
-}
+} //namespace impl
 
 template<typename T>
 T bitfieldExtract( T val, uint32_t offsetBits, uint32_t numBits )

include/nbl/builtin/hlsl/memory_accessor.hlsl
125 additions, 39 deletions

@@ -5,6 +5,7 @@
 #define _NBL_BUILTIN_HLSL_MEMORY_ACCESSOR_INCLUDED_
 
 #include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
+#include "nbl/builtin/hlsl/member_test_macros.hlsl"
 
 // weird namespace placing, see the comment where the macro is defined
 GENERATE_METHOD_TESTER(atomicExchange)
@@ -33,104 +34,189 @@ struct pair
     second_type second;
 };
 
+namespace accessor_adaptors
+{
+namespace impl
+{
+// only base class to use integral_constant because we need to use void to indicate a dynamic value and all values are valid
+template<typename IndexType, typename Offset>
+struct OffsetBase
+{
+    NBL_CONSTEXPR IndexType offset = Offset::value;
+};
+template<typename IndexType>
+struct OffsetBase<IndexType,void>
+{
+    IndexType offset;
+};
+
+template<typename IndexType, uint64_t ElementStride, uint64_t SubElementStride, typename Offset>
+struct StructureOfArraysStrides
+{
+    NBL_CONSTEXPR IndexType elementStride = ElementStride;
+    NBL_CONSTEXPR IndexType subElementStride = SubElementStride;
+
+    //static_assert(elementStride>0 && subElementStride>0);
+};
+template<typename IndexType, typename Offset>
+struct StructureOfArraysStrides<IndexType,0,0,Offset> : OffsetBase<IndexType,Offset>
+{
+    IndexType elementStride;
+    IndexType subElementStride;
+};
+#if 0 // don't seem to be able to specialize one at a time
+template<typename IndexType, uint64_t ElementStride, typename Offset>
+struct StructureOfArraysStrides<IndexType,ElementStride,0,Offset> : OffsetBase<IndexType,Offset>
+{
+    NBL_CONSTEXPR IndexType elementStride = ElementStride;
+    IndexType subElementStride;
+};
+template<typename IndexType, uint64_t SubElementStride, typename Offset>
+struct StructureOfArraysStrides<IndexType,0,SubElementStride,Offset> : OffsetBase<IndexType,Offset>
+{
+    IndexType elementStride;
+    NBL_CONSTEXPR IndexType subElementStride = SubElementStride;
+};
+#endif
+
+
+template<typename IndexType, uint64_t ElementStride, uint64_t SubElementStride, typename Offset>
+struct StructureOfArraysBase : StructureOfArraysStrides<IndexType,ElementStride,SubElementStride,Offset>
+{
+    IndexType getIx(const IndexType ix, const IndexType el)
+    {
+        using base_t = StructureOfArraysStrides<IndexType,ElementStride,SubElementStride,Offset>;
+        return base_t::elementStride*ix+base_t::subElementStride*el+OffsetBase<IndexType,Offset>::offset;
+    }
+};
+
+// maybe we should have our own std::array
+template<typename T, uint64_t count>
+struct array
+{
+    T data[count];
+};
+}
 
-// TODO: find some cool way to SFINAE the default into `_NBL_HLSL_WORKGROUP_SIZE_` if defined, and something like 1 otherwise
-template<class BaseAccessor, typename AccessType, typename IndexType=uint32_t, typename Strides=pair<integral_constant<IndexType,1>,integral_constant<IndexType,_NBL_HLSL_WORKGROUP_SIZE_> > >
-struct MemoryAdaptor // TODO: rename to something nicer like StructureOfArrays and add a `namespace accessor_adaptors`
+// TODO: some CRTP thing to forward through atomics and barriers
+
+// If you want static strides pass `Stride=pair<integral_constant<IndexType,ElementStride>,integral_constant<IndexType,SubElementStride> >`
+template<class BaseAccessor, typename AccessType, typename IndexType=uint32_t, uint64_t ElementStride=0, uint64_t SubElementStride=0, typename _Offset=integral_constant<IndexType,0> >
+struct StructureOfArrays : impl::StructureOfArraysBase<IndexType,ElementStride,SubElementStride,_Offset>
 {
+    using base_t = impl::StructureOfArraysBase<IndexType,ElementStride,SubElementStride,_Offset>;
     // Question: should the `BaseAccessor` let us know what this is?
     using access_t = AccessType;
     using index_t = IndexType;
-    NBL_CONSTEXPR index_t ElementStride = Strides::first_type::value;
-    NBL_CONSTEXPR index_t SubElementStride = Strides::second_type::value;
 
     BaseAccessor accessor;
-
-    access_t get(const index_t ix)
-    {
-        access_t retVal;
-        get<access_t>(ix,retVal);
-        return retVal;
-    }
 
     // Question: shall we go back to requiring a `access_t get(index_t)` on the `BaseAccessor`, then we could `enable_if` check the return type (via `has_method_get`) matches and we won't get Nasty HLSL copy-in copy-out conversions
     template<typename T>
     enable_if_t<sizeof(T)%sizeof(access_t)==0,void> get(const index_t ix, NBL_REF_ARG(T) value)
     {
         NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
-        access_t aux[SubElementCount];
-        for (uint64_t i=0; i<SubElementCount; i++)
-            accessor.get(ix*ElementStride+i*SubElementStride,aux[i]);
-        value = bit_cast<T,access_t[SubElementCount]>(aux);
+        // `vector` for now, we'll use `array` later
+        vector<access_t,SubElementCount> aux;
+        for (index_t i=0; i<SubElementCount; i++)
+            accessor.get(base_t::getIx(ix,i),aux[i]);
+        value = bit_cast<T,vector<access_t,SubElementCount> >(aux);
     }
 
     template<typename T>
     enable_if_t<sizeof(T)%sizeof(access_t)==0,void> set(const index_t ix, NBL_CONST_REF_ARG(T) value)
     {
         NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
-        access_t aux[SubElementCount] = bit_cast<access_t[SubElementCount],T>(value);
-        for (uint64_t i=0; i<SubElementCount; i++)
-            accessor.set(ix*ElementStride+i*SubElementStride,aux[i]);
+        // `vector` for now, we'll use `array` later
+        vector<access_t,SubElementCount> aux;
+        aux = bit_cast<vector<access_t,SubElementCount>,T>(value);
+        for (index_t i=0; i<SubElementCount; i++)
+            accessor.set(base_t::getIx(ix,i),aux[i]);
+
     }
 
     template<typename T, typename S=BaseAccessor>
     enable_if_t<
-        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicExchange<S,index_t,access_t>::return_type,access_t>,void
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<typename has_method_atomicExchange<S,index_t,access_t>::return_type,access_t>,void
     > atomicExchange(const index_t ix, const T value, NBL_REF_ARG(T) orig)
     {
-        orig = bit_cast<T,access_t>(accessor.atomicExchange(ix,bit_cast<access_t,T>(value)));
+        orig = bit_cast<T,access_t>(accessor.atomicExchange(getIx(ix),bit_cast<access_t,T>(value)));
     }
     template<typename T, typename S=BaseAccessor>
     enable_if_t<
-        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicCompSwap<S,index_t,access_t,access_t>::return_type,access_t>,void
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<typename has_method_atomicCompSwap<S,index_t,access_t,access_t>::return_type,access_t>,void
    > atomicCompSwap(const index_t ix, const T value, const T comp, NBL_REF_ARG(T) orig)
     {
-        orig = bit_cast<T,access_t>(accessor.atomicCompSwap(ix,bit_cast<access_t,T>(comp),bit_cast<access_t,T>(value)));
+        orig = bit_cast<T,access_t>(accessor.atomicCompSwap(getIx(ix),bit_cast<access_t,T>(comp),bit_cast<access_t,T>(value)));
     }
 
     template<typename T, typename S=BaseAccessor>
     enable_if_t<
-        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicAnd<S,index_t,access_t>::return_type,access_t>,void
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<typename has_method_atomicAnd<S,index_t,access_t>::return_type,access_t>,void
    > atomicAnd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
     {
-        orig = bit_cast<T,access_t>(accessor.atomicAnd(ix,bit_cast<access_t,T>(value)));
+        orig = bit_cast<T,access_t>(accessor.atomicAnd(getIx(ix),bit_cast<access_t,T>(value)));
     }
     template<typename T, typename S=BaseAccessor>
     enable_if_t<
-        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicOr<S,index_t,access_t>::return_type,access_t>,void
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<typename has_method_atomicOr<S,index_t,access_t>::return_type,access_t>,void
    > atomicOr(const index_t ix, const T value, NBL_REF_ARG(T) orig)
     {
-        orig = bit_cast<T,access_t>(accessor.atomicOr(ix,bit_cast<access_t,T>(value)));
+        orig = bit_cast<T,access_t>(accessor.atomicOr(getIx(ix),bit_cast<access_t,T>(value)));
     }
     template<typename T, typename S=BaseAccessor>
     enable_if_t<
-        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicXor<S,index_t,access_t>::return_type,access_t>,void
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<typename has_method_atomicXor<S,index_t,access_t>::return_type,access_t>,void
    > atomicXor(const index_t ix, const T value, NBL_REF_ARG(T) orig)
     {
-        orig = bit_cast<T,access_t>(accessor.atomicXor(ix,bit_cast<access_t,T>(value)));
+        orig = bit_cast<T,access_t>(accessor.atomicXor(getIx(ix),bit_cast<access_t,T>(value)));
     }
 
     // This has the upside of never calling a `(uint32_t)(uint32_t,uint32_t)` overload of `atomicAdd` because it checks the return type!
     // If someone makes a `(float)(uint32_t,uint32_t)` they will break this detection code, but oh well.
     template<typename T>
-    enable_if_t<is_same_v<has_method_atomicAdd<BaseAccessor,index_t,T>::return_type,T>,void> atomicAdd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    enable_if_t<is_same_v<typename has_method_atomicAdd<BaseAccessor,index_t,T>::return_type,T>,void> atomicAdd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
     {
-        orig = accessor.atomicAdd(ix,value);
+        orig = accessor.atomicAdd(getIx(ix),value);
     }
     template<typename T>
-    enable_if_t<is_same_v<has_method_atomicMin<BaseAccessor,index_t,T>::return_type,T>,void> atomicMin(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    enable_if_t<is_same_v<typename has_method_atomicMin<BaseAccessor,index_t,T>::return_type,T>,void> atomicMin(const index_t ix, const T value, NBL_REF_ARG(T) orig)
     {
-        orig = accessor.atomicMin(ix,value);
+        orig = accessor.atomicMin(getIx(ix),value);
     }
     template<typename T>
-    enable_if_t<is_same_v<has_method_atomicMax<BaseAccessor,index_t,T>::return_type,T>,void> atomicMax(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    enable_if_t<is_same_v<typename has_method_atomicMax<BaseAccessor,index_t,T>::return_type,T>,void> atomicMax(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicMax(getIx(ix),value);
+    }
+
+    template<typename S=BaseAccessor>
+    enable_if_t<
+        is_same_v<S,BaseAccessor> && is_same_v<typename has_method_workgroupExecutionAndMemoryBarrier<S>::return_type,void>,void
+    > workgroupExecutionAndMemoryBarrier()
     {
-        orig = accessor.atomicMax(ix,value);
+        accessor.workgroupExecutionAndMemoryBarrier();
     }
+};
+
+// ---------------------------------------------- Offset Accessor ----------------------------------------------------
+
+template<class BaseAccessor, typename IndexType=uint32_t, typename _Offset=void>
+struct Offset : impl::OffsetBase<IndexType,_Offset>
+{
+    using base_t = impl::OffsetBase<IndexType,_Offset>;
+
+    BaseAccessor accessor;
+
+    template <typename T>
+    void set(uint32_t idx, T value) {accessor.set(idx+base_t::offset,value); }
+
+    template <typename T>
+    void get(uint32_t idx, NBL_REF_ARG(T) value) {accessor.get(idx+base_t::offset,value);}
 
     template<typename S=BaseAccessor>
     enable_if_t<
-        is_same_v<S,BaseAccessor> && is_same_v<has_method_workgroupExecutionAndMemoryBarrier<S>::return_type,void>,void
+        is_same_v<S,BaseAccessor> && is_same_v<typename has_method_workgroupExecutionAndMemoryBarrier<S>::return_type,void>,void
    > workgroupExecutionAndMemoryBarrier()
     {
         accessor.workgroupExecutionAndMemoryBarrier();
@@ -139,5 +225,5 @@ struct MemoryAdaptor // TODO: rename to something nicer like StructureOfArrays a
 
 }
 }
-
+}
 #endif
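
A hedged sketch of how the renamed adaptor might be wired up to shared memory; the `scratch` array, `SharedAccessor` struct and the stride choices are illustrative, not from the diff. With `ElementStride=1` and `SubElementStride=64`, a `complex_t<float32_t>` (two `float32_t` sub-elements) gets scattered 64 slots apart, i.e. a structure-of-arrays layout for a 64-invocation workgroup:

    groupshared float32_t scratch[128];

    struct SharedAccessor
    {
        void get(const uint32_t ix, NBL_REF_ARG(float32_t) value) { value = scratch[ix]; }
        void set(const uint32_t ix, const float32_t value) { scratch[ix] = value; }
        void workgroupExecutionAndMemoryBarrier() { nbl::hlsl::glsl::barrier(); } // assumed glsl_compat helper
    };

    // static strides: element stride 1, sub-element stride 64, offset defaults to 0
    using adaptor_t = nbl::hlsl::accessor_adaptors::StructureOfArrays<SharedAccessor,float32_t,uint32_t,1,64>;

    void soaExample(uint32_t ix)
    {
        adaptor_t adaptor;
        complex_t<float32_t> z;
        z.real(1.0); z.imag(0.0);
        adaptor.set(ix, z); // writes scratch[ix] and scratch[ix+64]
        adaptor.workgroupExecutionAndMemoryBarrier();
        adaptor.get(ix, z);
    }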

include/nbl/builtin/hlsl/mpl.hlsl
5 additions, 0 deletions

@@ -63,6 +63,11 @@ struct rotr
     static const T value = (S >= 0) ? ((X >> r) | (X << (N - r))) : (X << (-r)) | (X >> (N - (-r)));
 };
 
+template<uint64_t N>
+struct is_pot : bool_constant< (N > 0 && !(N & (N - 1))) > {};
+
+template<uint64_t N>
+NBL_CONSTEXPR_STATIC_INLINE bool is_pot_v = is_pot<N>::value;
 
 }
 }
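
`is_pot` uses the classic bit trick: `N & (N - 1)` clears the lowest set bit, so it is zero exactly when `N` has a single bit set, and the `N > 0` term rules out zero. For instance:

    static_assert(nbl::hlsl::mpl::is_pot_v<64>);  // 0b1000000 & 0b0111111 == 0
    static_assert(!nbl::hlsl::mpl::is_pot_v<48>); // 0b0110000 & 0b0101111 == 0b0100000
    static_assert(!nbl::hlsl::mpl::is_pot_v<0>);  // excluded by the N > 0 guard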

include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl
5 additions, 0 deletions

@@ -54,6 +54,11 @@ using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer,vk::Literal<vk::integra
 
 //! General Operations
 
+// The holy operation that makes addrof possible
+template<uint32_t StorageClass, typename T>
+[[vk::ext_instruction(spv::OpCopyObject)]]
+pointer_t<StorageClass,T> copyObject([[vk::ext_reference]] T v);
+
 // Here's the thing with atomics, it's not only the data type that dictates whether you can do an atomic or not.
 // It's the storage class that has the most effect (shared vs storage vs image) and we can't check that easily
 template<typename T> // integers operate on 2s complement so same op for signed and unsigned
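
Why `OpCopyObject` can act as an address-of here (an assumption about the lowering, not stated in the diff): DXC passes a `[[vk::ext_reference]]` argument as the pointer backing the variable, and `OpCopyObject` forwards that pointer as an ordinary result value. A hedged direct-use sketch, mirroring the `addrof` wrapper that `bda/__ref.hlsl` builds on top of this intrinsic:

    void copyObjectExample()
    {
        uint32_t v = 42u;
        // the result is the Function-storage pointer to v itself
        nbl::hlsl::spirv::pointer_t<spv::StorageClassFunction,uint32_t> pv =
            nbl::hlsl::spirv::copyObject<spv::StorageClassFunction,uint32_t>(v);
    }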
