
Commit 27919be

Author: devsh (committed)
add a whole bunch of enable_if to memory_accessor.hlsl
1 parent d2b824a commit 27919be

File tree

3 files changed: +105 −96 lines changed

include/nbl/builtin/hlsl/member_test_macros.hlsl

Lines changed: 5 additions & 2 deletions
@@ -108,6 +108,7 @@ struct has_method_##x<T BOOST_PP_REPEAT(n, NBL_TYPE_FWD, n), typename make_void<
    due to how we check function signatures at the moment
 */
 
+// TODO: these should probably generate without a namespace and be expected to be put inside a namespace
 #define GENERATE_METHOD_TESTER(x) \
 namespace nbl { \
 namespace hlsl { \
@@ -118,8 +119,10 @@ BOOST_PP_REPEAT(4, GENERATE_METHOD_TESTER_SPEC, x) \
 }}
 
 
-GENERATE_METHOD_TESTER(a)
-GENERATE_METHOD_TESTER(b)
+GENERATE_METHOD_TESTER(a) // TODO: remove
+GENERATE_METHOD_TESTER(b) // TODO: remove
+GENERATE_METHOD_TESTER(get)
+GENERATE_METHOD_TESTER(set)
 
 
 #endif
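
For context: each `GENERATE_METHOD_TESTER(name)` expansion produces a `nbl::hlsl::has_method_name<T, Args...>` trait whose `return_type` member reflects what `T::name(Args...)` returns when such an overload exists, which is exactly how the new `get`/`set` testers get used for SFINAE in `memory_accessor.hlsl` below. A minimal sketch of that pattern, assuming a hypothetical `MyAccessor` type and the usual Nabla helpers (`NBL_REF_ARG`, `enable_if_t`, `is_same_v`) in scope; this is illustration only, not part of the commit:

// hypothetical accessor type, only here to show the trait query
struct MyAccessor
{
    void get(const uint32_t ix, NBL_REF_ARG(uint32_t) value) { value = ix; }
    void set(const uint32_t ix, const uint32_t value) {}
};

// only participates in overload resolution when `Accessor::get(uint32_t,uint32_t)` exists and returns void
template<typename Accessor>
nbl::hlsl::enable_if_t<nbl::hlsl::is_same_v<nbl::hlsl::has_method_get<Accessor,uint32_t,uint32_t>::return_type,void>,uint32_t>
firstElement(Accessor accessor)
{
    uint32_t retVal;
    accessor.get(0u, retVal);
    return retVal;
}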

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 99 additions & 93 deletions
@@ -6,125 +6,131 @@
 
 #include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
 
+// weird namespace placing, see the comment where the macro is defined
+GENERATE_METHOD_TESTER(atomicExchange)
+GENERATE_METHOD_TESTER(atomicCompSwap)
+GENERATE_METHOD_TESTER(atomicAnd)
+GENERATE_METHOD_TESTER(atomicOr)
+GENERATE_METHOD_TESTER(atomicXor)
+GENERATE_METHOD_TESTER(atomicAdd)
+GENERATE_METHOD_TESTER(atomicMin)
+GENERATE_METHOD_TESTER(atomicMax)
+GENERATE_METHOD_TESTER(workgroupExecutionAndMemoryBarrier)
+
 namespace nbl
 {
 namespace hlsl
 {
 
-template<class BaseAccessor>
-struct MemoryAdaptor
+// TODO: flesh out and move to `nbl/builtin/hlsl/utility.hlsl`
+template<typename T1, typename T2>
+struct pair
 {
+    using first_type = T1;
+    using second_type = T2;
+
+    first_type first;
+    second_type second;
+};
+
+
+// TODO: find some cool way to SFINAE the default into `_NBL_HLSL_WORKGROUP_SIZE_` if defined, and something like 1 otherwise
+template<class BaseAccessor, typename AccessType=uint32_t, typename IndexType=uint32_t, typename Strides=pair<integral_constant<IndexType,1>,integral_constant<IndexType,_NBL_HLSL_WORKGROUP_SIZE_> > >
+struct MemoryAdaptor // TODO: rename to something nicer like StructureOfArrays and add a `namespace accessor_adaptors`
+{
+    using access_t = AccessType;
+    using index_t = IndexType;
+    NBL_CONSTEXPR index_t ElementStride = Strides::first_type::value;
+    NBL_CONSTEXPR index_t SubElementStride = Strides::second_type::value;
+
     BaseAccessor accessor;
 
-    // TODO: template atomic... then add static_asserts of `has_method<BaseAccessor,signature>::value`, do vectors and matrices in terms of each other
-    uint get(const uint ix)
+    access_t get(const index_t ix)
     {
-        uint retVal;
-        accessor.get(ix, retVal);
+        access_t retVal;
+        get<access_t>(ix,retVal);
         return retVal;
     }
 
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(Scalar) value)
+    template<typename T>
+    enable_if_t<sizeof(T)%sizeof(access_t)==0,void> get(const index_t ix, NBL_REF_ARG(T) value)
     {
-        uint32_t aux;
-        accessor.get(ix, aux);
-        value = bit_cast<Scalar, uint32_t>(aux);
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value)
-    {
-        uint32_t2 aux;
-        accessor.get(ix, aux.x);
-        accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, aux.y);
-        value = bit_cast<vector<Scalar, 2>, uint32_t2>(aux);
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value)
-    {
-        uint32_t3 aux;
-        accessor.get(ix, aux.x);
-        accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, aux.y);
-        accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, aux.z);
-        value = bit_cast<vector<Scalar, 3>, uint32_t3>(aux);
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value)
-    {
-        uint32_t4 aux;
-        accessor.get(ix, aux.x);
-        accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, aux.y);
-        accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, aux.z);
-        accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, aux.w);
-        value = bit_cast<vector<Scalar, 3>, uint32_t4>(aux);
+        NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
+        access_t aux[SubElementCount];
+        for (uint64_t i=0; i<SubElementCount; i++)
+            accessor.get(ix*ElementStride+i*SubElementStride,aux[i]);
+        value = bit_cast<T,access_t[SubElementCount]>(aux);
     }
 
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const Scalar value) {accessor.set(ix, asuint(value));}
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 2> value) {
-        accessor.set(ix, asuint(value.x));
-        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 3> value) {
-        accessor.set(ix, asuint(value.x));
-        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
-        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 4> value) {
-        accessor.set(ix, asuint(value.x));
-        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
-        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
-        accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.w));
+    template<typename T>
+    enable_if_t<sizeof(T)%sizeof(access_t)==0,void> set(const index_t ix, NBL_CONST_REF_ARG(T) value)
+    {
+        NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
+        access_t aux[SubElementCount] = bit_cast<access_t[SubElementCount],T>(value);
+        for (uint64_t i=0; i<SubElementCount; i++)
+            accessor.set(ix*ElementStride+i*SubElementStride,aux[i]);
     }
 
-    void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicAnd(ix, value);
-    }
-    void atomicAnd(const uint ix, const int value, NBL_REF_ARG(int) orig) {
-        orig = asint(accessor.atomicAnd(ix, asuint(value)));
-    }
-    void atomicAnd(const uint ix, const float value, NBL_REF_ARG(float) orig) {
-        orig = asfloat(accessor.atomicAnd(ix, asuint(value)));
-    }
-    void atomicOr(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicOr(ix, value);
-    }
-    void atomicOr(const uint ix, const int value, NBL_REF_ARG(int) orig) {
-        orig = asint(accessor.atomicOr(ix, asuint(value)));
-    }
-    void atomicOr(const uint ix, const float value, NBL_REF_ARG(float) orig) {
-        orig = asfloat(accessor.atomicOr(ix, asuint(value)));
-    }
-    void atomicXor(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicXor(ix, value);
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicExchange<S,index_t,access_t>::return_type,access_t>,void
+    > atomicExchange(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicExchange(ix,bit_cast<access_t,T>(value)));
     }
-    void atomicXor(const uint ix, const int value, NBL_REF_ARG(int) orig) {
-        orig = asint(accessor.atomicXor(ix, asuint(value)));
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicCompSwap<S,index_t,access_t,access_t>::return_type,access_t>,void
+    > atomicCompSwap(const index_t ix, const T value, const T comp, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicCompSwap(ix,bit_cast<access_t,T>(comp),bit_cast<access_t,T>(value)));
     }
-    void atomicXor(const uint ix, const float value, NBL_REF_ARG(float) orig) {
-        orig = asfloat(accessor.atomicXor(ix, asuint(value)));
+
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicAnd<S,index_t,access_t>::return_type,access_t>,void
+    > atomicAnd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicAnd(ix,bit_cast<access_t,T>(value)));
    }
-    void atomicAdd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicAdd(ix, value);
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicOr<S,index_t,access_t>::return_type,access_t>,void
+    > atomicOr(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicOr(ix,bit_cast<access_t,T>(value)));
     }
-    void atomicMin(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicMin(ix, value);
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicXor<S,index_t,access_t>::return_type,access_t>,void
+    > atomicXor(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicXor(ix,bit_cast<access_t,T>(value)));
     }
-    void atomicMax(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicMax(ix, value);
+
+    // This has the upside of never calling a `(uint32_t)(uint32_t,uint32_t)` overload of `atomicAdd` because it checks the return type!
+    // If someone makes a `(float)(uint32_t,uint32_t)` they will break this detection code, but oh well.
+    template<typename T>
+    enable_if_t<is_same_v<has_method_atomicAdd<BaseAccessor,index_t,T>::return_type,T>,void> atomicAdd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicAdd(ix,value);
     }
-    void atomicExchange(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicExchange(ix, value);
+    template<typename T>
+    enable_if_t<is_same_v<has_method_atomicMin<BaseAccessor,index_t,T>::return_type,T>,void> atomicMin(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicMin(ix,value);
     }
-    void atomicCompSwap(const uint ix, const uint value, const uint comp, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicCompSwap(ix, comp, value);
+    template<typename T>
+    enable_if_t<is_same_v<has_method_atomicMax<BaseAccessor,index_t,T>::return_type,T>,void> atomicMax(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicMax(ix,value);
     }
 
-    // TODO: figure out the `enable_if` syntax for this
-    void workgroupExecutionAndMemoryBarrier() {
+    template<typename S=BaseAccessor>
+    enable_if_t<
+        is_same_v<S,BaseAccessor> && is_same_v<has_method_workgroupExecutionAndMemoryBarrier<S>::return_type,void>,void
+    > workgroupExecutionAndMemoryBarrier()
+    {
        accessor.workgroupExecutionAndMemoryBarrier();
     }
 };
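
For orientation, a minimal usage sketch (not part of the commit): after this change the adaptor only exposes an atomic or barrier when the wrapped accessor declares a method with the matching name and return type, and `get`/`set` stride any `T` whose size is a multiple of the access type across `sizeof(T)/sizeof(access_t)` sub-element slots. The `ScratchProxy` type, the `scratch` array, and the `exampleUsage` function below are assumptions for illustration, and the default `Strides` requires `_NBL_HLSL_WORKGROUP_SIZE_` to be defined (as the TODO in the diff notes):

// hypothetical groupshared backing store, two uint32_t sub-element slots per invocation
groupshared uint32_t scratch[_NBL_HLSL_WORKGROUP_SIZE_*2];

struct ScratchProxy
{
    // `get`/`set` over the access type are all MemoryAdaptor strictly needs
    void get(const uint32_t ix, NBL_REF_ARG(uint32_t) value) { value = scratch[ix]; }
    void set(const uint32_t ix, const uint32_t value) { scratch[ix] = value; }

    // declaring these makes the matching MemoryAdaptor methods visible through the has_method_* checks
    uint32_t atomicAdd(const uint32_t ix, const uint32_t value)
    {
        uint32_t orig;
        InterlockedAdd(scratch[ix], value, orig);
        return orig;
    }
    void workgroupExecutionAndMemoryBarrier() { GroupMemoryBarrierWithGroupSync(); }
};

void exampleUsage(const uint32_t invocationID)
{
    nbl::hlsl::MemoryAdaptor<ScratchProxy> adaptor;
    uint32_t2 v;
    adaptor.get(invocationID, v);  // reads scratch[invocationID] and scratch[invocationID+_NBL_HLSL_WORKGROUP_SIZE_]
    adaptor.set(invocationID, v);
    uint32_t prev;
    adaptor.atomicAdd(invocationID, 1u, prev); // enabled because ScratchProxy::atomicAdd(uint32_t,uint32_t) returns uint32_t
    adaptor.workgroupExecutionAndMemoryBarrier();
}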

src/nbl/builtin/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/ext/FullScreenTriangle/defaul
 LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/text_rendering/msdf.hlsl")
 #memory
 LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/memory.hlsl")
-
+LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/memory_accessor.hlsl")
 #enums
 LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/enums.hlsl")
