@@ -6,125 +6,131 @@
 
 #include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
 
+// weird namespace placing, see the comment where the macro is defined
+GENERATE_METHOD_TESTER(atomicExchange)
+GENERATE_METHOD_TESTER(atomicCompSwap)
+GENERATE_METHOD_TESTER(atomicAnd)
+GENERATE_METHOD_TESTER(atomicOr)
+GENERATE_METHOD_TESTER(atomicXor)
+GENERATE_METHOD_TESTER(atomicAdd)
+GENERATE_METHOD_TESTER(atomicMin)
+GENERATE_METHOD_TESTER(atomicMax)
+GENERATE_METHOD_TESTER(workgroupExecutionAndMemoryBarrier)
+
 namespace nbl
 {
 namespace hlsl
 {
 
-template<class BaseAccessor>
-struct MemoryAdaptor
+// TODO: flesh out and move to `nbl/builtin/hlsl/utility.hlsl`
+template<typename T1, typename T2>
+struct pair
 {
+    using first_type = T1;
+    using second_type = T2;
+
+    first_type first;
+    second_type second;
+};
+
+
+// TODO: find some cool way to SFINAE the default into `_NBL_HLSL_WORKGROUP_SIZE_` if defined, and something like 1 otherwise
+template<class BaseAccessor, typename AccessType=uint32_t, typename IndexType=uint32_t, typename Strides=pair<integral_constant<IndexType,1>,integral_constant<IndexType,_NBL_HLSL_WORKGROUP_SIZE_> > >
+struct MemoryAdaptor // TODO: rename to something nicer like StructureOfArrays and add a `namespace accessor_adaptors`
+{
+    using access_t = AccessType;
+    using index_t = IndexType;
+    NBL_CONSTEXPR index_t ElementStride = Strides::first_type::value;
+    NBL_CONSTEXPR index_t SubElementStride = Strides::second_type::value;
+
     BaseAccessor accessor;
 
-    // TODO: template atomic... then add static_asserts of `has_method<BaseAccessor,signature>::value`, do vectors and matrices in terms of each other
-    uint get(const uint ix)
+    access_t get(const index_t ix)
     {
-        uint retVal;
-        accessor.get(ix, retVal);
+        access_t retVal;
+        get<access_t>(ix,retVal);
         return retVal;
     }
 
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(Scalar) value)
+    template<typename T>
+    enable_if_t<sizeof(T)%sizeof(access_t)==0,void> get(const index_t ix, NBL_REF_ARG(T) value)
     {
-        uint32_t aux;
-        accessor.get(ix, aux);
-        value = bit_cast<Scalar, uint32_t>(aux);
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value)
-    {
-        uint32_t2 aux;
-        accessor.get(ix, aux.x);
-        accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, aux.y);
-        value = bit_cast<vector<Scalar, 2>, uint32_t2>(aux);
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value)
-    {
-        uint32_t3 aux;
-        accessor.get(ix, aux.x);
-        accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, aux.y);
-        accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, aux.z);
-        value = bit_cast<vector<Scalar, 3>, uint32_t3>(aux);
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value)
-    {
-        uint32_t4 aux;
-        accessor.get(ix, aux.x);
-        accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, aux.y);
-        accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, aux.z);
-        accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, aux.w);
-        value = bit_cast<vector<Scalar, 3>, uint32_t4>(aux);
+        NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
+        access_t aux[SubElementCount];
+        for (uint64_t i=0; i<SubElementCount; i++)
+            accessor.get(ix*ElementStride+i*SubElementStride,aux[i]);
+        value = bit_cast<T,access_t[SubElementCount]>(aux);
     }
 
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const Scalar value) {accessor.set(ix, asuint(value));}
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 2> value) {
-        accessor.set(ix, asuint(value.x));
-        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 3> value) {
-        accessor.set(ix, asuint(value.x));
-        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
-        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
-    }
-    template<typename Scalar>
-    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 4> value) {
-        accessor.set(ix, asuint(value.x));
-        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
-        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
-        accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.w));
+    template<typename T>
+    enable_if_t<sizeof(T)%sizeof(access_t)==0,void> set(const index_t ix, NBL_CONST_REF_ARG(T) value)
+    {
+        NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
+        access_t aux[SubElementCount] = bit_cast<access_t[SubElementCount],T>(value);
+        for (uint64_t i=0; i<SubElementCount; i++)
+            accessor.set(ix*ElementStride+i*SubElementStride,aux[i]);
     }
 
-    void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicAnd(ix, value);
-    }
-    void atomicAnd(const uint ix, const int value, NBL_REF_ARG(int) orig) {
-        orig = asint(accessor.atomicAnd(ix, asuint(value)));
-    }
-    void atomicAnd(const uint ix, const float value, NBL_REF_ARG(float) orig) {
-        orig = asfloat(accessor.atomicAnd(ix, asuint(value)));
-    }
-    void atomicOr(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicOr(ix, value);
-    }
-    void atomicOr(const uint ix, const int value, NBL_REF_ARG(int) orig) {
-        orig = asint(accessor.atomicOr(ix, asuint(value)));
-    }
-    void atomicOr(const uint ix, const float value, NBL_REF_ARG(float) orig) {
-        orig = asfloat(accessor.atomicOr(ix, asuint(value)));
-    }
-    void atomicXor(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicXor(ix, value);
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicExchange<S,index_t,access_t>::return_type,access_t>,void
+    > atomicExchange(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicExchange(ix,bit_cast<access_t,T>(value)));
     }
-    void atomicXor(const uint ix, const int value, NBL_REF_ARG(int) orig) {
-        orig = asint(accessor.atomicXor(ix, asuint(value)));
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicCompSwap<S,index_t,access_t,access_t>::return_type,access_t>,void
+    > atomicCompSwap(const index_t ix, const T value, const T comp, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicCompSwap(ix,bit_cast<access_t,T>(comp),bit_cast<access_t,T>(value)));
     }
-    void atomicXor(const uint ix, const float value, NBL_REF_ARG(float) orig) {
-        orig = asfloat(accessor.atomicXor(ix, asuint(value)));
+
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicAnd<S,index_t,access_t>::return_type,access_t>,void
+    > atomicAnd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicAnd(ix,bit_cast<access_t,T>(value)));
     }
-    void atomicAdd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicAdd(ix, value);
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicOr<S,index_t,access_t>::return_type,access_t>,void
+    > atomicOr(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicOr(ix,bit_cast<access_t,T>(value)));
     }
-    void atomicMin(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicMin(ix, value);
+    template<typename T, typename S=BaseAccessor>
+    enable_if_t<
+        sizeof(T)==sizeof(access_t) && is_same_v<S,BaseAccessor> && is_same_v<has_method_atomicXor<S,index_t,access_t>::return_type,access_t>,void
+    > atomicXor(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = bit_cast<T,access_t>(accessor.atomicXor(ix,bit_cast<access_t,T>(value)));
    }
-    void atomicMax(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicMax(ix, value);
+
+    // This has the upside of never calling a `(uint32_t)(uint32_t,uint32_t)` overload of `atomicAdd` because it checks the return type!
+    // If someone makes a `(float)(uint32_t,uint32_t)` they will break this detection code, but oh well.
+    template<typename T>
+    enable_if_t<is_same_v<has_method_atomicAdd<BaseAccessor,index_t,T>::return_type,T>,void> atomicAdd(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicAdd(ix,value);
     }
-    void atomicExchange(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicExchange(ix, value);
+    template<typename T>
+    enable_if_t<is_same_v<has_method_atomicMin<BaseAccessor,index_t,T>::return_type,T>,void> atomicMin(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicMin(ix,value);
     }
-    void atomicCompSwap(const uint ix, const uint value, const uint comp, NBL_REF_ARG(uint) orig) {
-        orig = accessor.atomicCompSwap(ix, comp, value);
+    template<typename T>
+    enable_if_t<is_same_v<has_method_atomicMax<BaseAccessor,index_t,T>::return_type,T>,void> atomicMax(const index_t ix, const T value, NBL_REF_ARG(T) orig)
+    {
+        orig = accessor.atomicMax(ix,value);
     }
 
-    // TODO: figure out the `enable_if` syntax for this
-    void workgroupExecutionAndMemoryBarrier() {
+    template<typename S=BaseAccessor>
+    enable_if_t<
+        is_same_v<S,BaseAccessor> && is_same_v<has_method_workgroupExecutionAndMemoryBarrier<S>::return_type,void>,void
+    > workgroupExecutionAndMemoryBarrier()
+    {
        accessor.workgroupExecutionAndMemoryBarrier();
     }
 };
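For context on how the reworked adaptor is meant to be driven, here is a minimal usage sketch (not part of the diff above). It assumes `_NBL_HLSL_WORKGROUP_SIZE_` is defined before the header is included (the default `Strides` reference it); the `scratch` array, `ScratchAccessor` struct, include path and entry point are hypothetical names picked for illustration. The adaptor only forwards the operations the wrapped accessor actually declares, thanks to the `has_method_*` testers generated at the top of the file.

// Illustrative sketch only, not part of this change; all names below are hypothetical.
#define _NBL_HLSL_WORKGROUP_SIZE_ 256
#include "nbl/builtin/hlsl/memory_accessor.hlsl" // assumed path of the file shown above

groupshared uint32_t scratch[4*_NBL_HLSL_WORKGROUP_SIZE_];

// fulfills the `get`/`set` contract used by the adaptor's load/store loops,
// plus one atomic and the barrier; every other adaptor method stays SFINAE'd out
struct ScratchAccessor
{
    void get(const uint32_t ix, NBL_REF_ARG(uint32_t) value) { value = scratch[ix]; }
    void set(const uint32_t ix, const uint32_t value) { scratch[ix] = value; }
    uint32_t atomicAdd(const uint32_t ix, const uint32_t value)
    {
        uint32_t orig;
        InterlockedAdd(scratch[ix], value, orig);
        return orig;
    }
    void workgroupExecutionAndMemoryBarrier() { GroupMemoryBarrierWithGroupSync(); }
};

[numthreads(_NBL_HLSL_WORKGROUP_SIZE_,1,1)]
void main(uint32_t3 LID : SV_GroupThreadID)
{
    nbl::hlsl::MemoryAdaptor<ScratchAccessor> adaptor;
    // with the default Strides, sub-element `i` of the value stored by invocation `ix`
    // lands at scratch[ix + i*_NBL_HLSL_WORKGROUP_SIZE_], i.e. the same SoA layout the
    // removed hardcoded vector overloads produced
    float32_t4 v = float32_t4(1.f,2.f,3.f,4.f);
    adaptor.set<float32_t4>(LID.x, v);
    adaptor.workgroupExecutionAndMemoryBarrier();
    float32_t4 r;
    adaptor.get<float32_t4>(LID.x, r);
    // compiles only because ScratchAccessor::atomicAdd returns uint32_t for a uint32_t
    // operand, satisfying the return-type check in the adaptor
    uint32_t prevValue;
    adaptor.atomicAdd<uint32_t>(LID.x, 1u, prevValue);
}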
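The `Strides` parameter is what replaces the old hardcoded `_NBL_HLSL_WORKGROUP_SIZE_` offsets: every sub-element access goes to `ix*ElementStride + i*SubElementStride`. The default `pair<integral_constant<IndexType,1>,integral_constant<IndexType,_NBL_HLSL_WORKGROUP_SIZE_> >` keeps the banked, structure-of-arrays addressing, while a different pair changes the layout without touching the accessor. A hypothetical alternative instantiation, reusing the `ScratchAccessor` from the sketch above purely to illustrate the stride math:

// Hypothetical, for illustration only: pack the four sub-elements of a 16-byte value
// contiguously, so element `ix`, sub-element `i` maps to scratch[ix*4 + i] (AoS-style).
typedef nbl::hlsl::pair<
    nbl::hlsl::integral_constant<uint32_t,4>, // ElementStride
    nbl::hlsl::integral_constant<uint32_t,1>  // SubElementStride
> AoSStrides;
typedef nbl::hlsl::MemoryAdaptor<ScratchAccessor,uint32_t,uint32_t,AoSStrides> AoSAdaptor;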