Skip to content

Commit 43b271d

Browse files
committed
Refactored AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER macro
1 parent 9d8de81 commit 43b271d

File tree

3 files changed

+147
-217
lines changed

3 files changed

+147
-217
lines changed

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

Lines changed: 80 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
#include <nbl/builtin/hlsl/concepts/matrix.hlsl>
1313
#include <nbl/builtin/hlsl/cpp_compat/promote.hlsl>
1414
#include <nbl/builtin/hlsl/numbers.hlsl>
15+
#ifndef __HLSL_VERSION
16+
#include <boost/preprocessor/comparison/not_equal.hpp>
17+
#include <boost/preprocessor/punctuation/comma_if.hpp>
18+
#include <boost/preprocessor/seq/for_each_i.hpp>
19+
#endif
1520

1621
namespace nbl
1722
{
@@ -82,93 +87,68 @@ struct refract_helper;
8287
#ifdef __HLSL_VERSION // HLSL only specializations
8388

8489
// it is crucial these partial specializations appear first because thats what makes the helpers match SPIR-V intrinsics first
85-
#define AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(HELPER_NAME, SPIRV_FUNCTION_NAME, RETURN_TYPE)\
86-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::SPIRV_FUNCTION_NAME<T>(experimental::declval<T>()))>)\
87-
struct HELPER_NAME<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::SPIRV_FUNCTION_NAME<T>(experimental::declval<T>()))>) >\
90+
91+
#define DECLVAL(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) experimental::declval<_T>()
92+
#define DECL_ARG(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) const _T arg##i
93+
#define WRAP(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) _T
94+
#define ARG(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) arg##i
95+
96+
// the template<> needs to be written ourselves
97+
// return type is __VA_ARGS__ to protect against `,` in templated return types
98+
#define AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(HELPER_NAME, SPIRV_FUNCTION_NAME, ARG_TYPE_LIST, ARG_TYPE_SET, ...)\
99+
NBL_PARTIAL_REQ_TOP(is_same_v<decltype(spirv::SPIRV_FUNCTION_NAME<T>(BOOST_PP_SEQ_FOR_EACH_I(DECLVAL, _, ARG_TYPE_SET))), __VA_ARGS__ >) \
100+
struct HELPER_NAME<BOOST_PP_SEQ_FOR_EACH_I(WRAP, _, ARG_TYPE_LIST) NBL_PARTIAL_REQ_BOT(is_same_v<decltype(spirv::SPIRV_FUNCTION_NAME<T>(BOOST_PP_SEQ_FOR_EACH_I(DECLVAL, _, ARG_TYPE_SET))), __VA_ARGS__ >) >\
88101
{\
89-
using return_t = RETURN_TYPE;\
90-
static inline return_t __call(const T arg)\
102+
using return_t = __VA_ARGS__;\
103+
static inline return_t __call( BOOST_PP_SEQ_FOR_EACH_I(DECL_ARG, _, ARG_TYPE_SET) )\
91104
{\
92-
return spirv::SPIRV_FUNCTION_NAME<T>(arg);\
105+
return spirv::SPIRV_FUNCTION_NAME<T>( BOOST_PP_SEQ_FOR_EACH_I(ARG, _, ARG_TYPE_SET) );\
93106
}\
94107
};
95108

96-
#define FIND_MSB_LSB_RETURN_TYPE conditional_t<is_vector_v<T>, vector<int32_t, vector_traits<T>::Dimension>, int32_t>;
97-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_msb_helper, findUMsb, FIND_MSB_LSB_RETURN_TYPE)
98-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_msb_helper, findSMsb, FIND_MSB_LSB_RETURN_TYPE)
99-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_lsb_helper, findILsb, FIND_MSB_LSB_RETURN_TYPE)
109+
#define FIND_MSB_LSB_RETURN_TYPE conditional_t<is_vector_v<T>, vector<int32_t, vector_traits<T>::Dimension>, int32_t>
110+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_msb_helper, findUMsb, (T), (T), FIND_MSB_LSB_RETURN_TYPE);
111+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_msb_helper, findSMsb, (T), (T), FIND_MSB_LSB_RETURN_TYPE)
112+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_lsb_helper, findILsb, (T), (T), FIND_MSB_LSB_RETURN_TYPE)
100113
#undef FIND_MSB_LSB_RETURN_TYPE
101114

102-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitReverse_helper, bitReverse, T)
103-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(transpose_helper, transpose, T)
104-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(length_helper, length, typename vector_traits<T>::scalar_type)
105-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(normalize_helper, normalize, T)
106-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(rsqrt_helper, inverseSqrt, T)
107-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(frac_helper, fract, T)
108-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(inverse_helper, matrixInverse, T)
109-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(all_helper, any, T)
110-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(any_helper, any, T)
111-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(sign_helper, fSign, T)
112-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(sign_helper, sSign, T)
113-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(radians_helper, radians, T)
114-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(degrees_helper, degrees, T)
115+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitReverse_helper, bitReverse, (T), (T), T)
116+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(transpose_helper, transpose, (T), (T), T)
117+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(length_helper, length, (T), (T), typename vector_traits<T>::scalar_type)
118+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(normalize_helper, normalize, (T), (T), T)
119+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(rsqrt_helper, inverseSqrt, (T), (T), T)
120+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(frac_helper, fract, (T), (T), T)
121+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(all_helper, any, (T), (T), T)
122+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(any_helper, any, (T), (T), T)
123+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(sign_helper, fSign, (T), (T), T)
124+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(sign_helper, sSign, (T), (T), T)
125+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(radians_helper, radians, (T), (T), T)
126+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(degrees_helper, degrees, (T), (T), T)
127+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(max_helper, fMax, (T), (T)(T), T)
128+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(max_helper, uMax, (T), (T)(T), T)
129+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(max_helper, sMax, (T), (T)(T), T)
130+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(min_helper, fMin, (T), (T)(T), T)
131+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(min_helper, uMin, (T), (T)(T), T)
132+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(min_helper, sMin, (T), (T)(T), T)
133+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(step_helper, step, (T), (T)(T), T)
134+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(reflect_helper, reflect, (T), (T)(T), T)
135+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(clamp_helper, fClamp, (T), (T)(T)(T), T)
136+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(clamp_helper, uClamp, (T), (T)(T)(T), T)
137+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(clamp_helper, sClamp, (T), (T)(T)(T), T)
138+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(smoothStep_helper, smoothStep, (T), (T)(T)(T), T)
139+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(faceForward_helper, faceForward, (T), (T)(T)(T), T)
140+
template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(refract_helper, refract, (T)(U), (T)(T)(U), T)
115141

116142
#define BITCOUNT_HELPER_RETRUN_TYPE conditional_t<is_vector_v<T>, vector<int32_t, vector_traits<T>::Dimension>, int32_t>
117-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitCount_helper, bitCount, BITCOUNT_HELPER_RETRUN_TYPE)
143+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitCount_helper, bitCount, (T), (T), BITCOUNT_HELPER_RETRUN_TYPE)
118144
#undef BITCOUNT_HELPER_RETRUN_TYPE
119145

146+
#undef DECLVAL
147+
#undef DECL_ARG
148+
#undef WRAP
149+
#undef ARG
120150
#undef AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER
121151

122-
#define AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(HELPER_NAME, SPIRV_FUNCTION_NAME, RETURN_TYPE)\
123-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::SPIRV_FUNCTION_NAME<T>(experimental::declval<T>(), experimental::declval<T>()))>)\
124-
struct HELPER_NAME<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::SPIRV_FUNCTION_NAME<T>(experimental::declval<T>(), experimental::declval<T>()))>) >\
125-
{\
126-
using return_t = RETURN_TYPE;\
127-
static inline return_t __call(const T a, const T b)\
128-
{\
129-
return spirv::SPIRV_FUNCTION_NAME<T>(a, b);\
130-
}\
131-
};
132-
133-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(max_helper, fMax, T)
134-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(max_helper, uMax, T)
135-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(max_helper, sMax, T)
136-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(min_helper, fMin, T)
137-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(min_helper, uMin, T)
138-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(min_helper, sMin, T)
139-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(step_helper, step, T)
140-
AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC(reflect_helper, reflect, T)
141-
142-
#undef AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER_2_ARG_FUNC
143-
144-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::fClamp<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>)
145-
struct clamp_helper<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::fClamp<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>) >
146-
{
147-
using return_t = T;
148-
static return_t __call(NBL_CONST_REF_ARG(T) val, NBL_CONST_REF_ARG(T) _min, NBL_CONST_REF_ARG(T) _max)
149-
{
150-
return spirv::fClamp(val, _min, _max);
151-
}
152-
};
153-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::uClamp<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>)
154-
struct clamp_helper<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::uClamp<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>) >
155-
{
156-
using return_t = T;
157-
static return_t __call(NBL_CONST_REF_ARG(T) val, NBL_CONST_REF_ARG(T) _min, NBL_CONST_REF_ARG(T) _max)
158-
{
159-
return spirv::uClamp(val, _min, _max);
160-
}
161-
};
162-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::sClamp<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>)
163-
struct clamp_helper<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::sClamp<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>) >
164-
{
165-
using return_t = T;
166-
static return_t __call(NBL_CONST_REF_ARG(T) val, NBL_CONST_REF_ARG(T) _min, NBL_CONST_REF_ARG(T) _max)
167-
{
168-
return spirv::sClamp(val, _min, _max);
169-
}
170-
};
171-
172152
template<typename UInt64> NBL_PARTIAL_REQ_TOP(is_same_v<UInt64, uint64_t>)
173153
struct find_msb_helper<UInt64 NBL_PARTIAL_REQ_BOT(is_same_v<UInt64, uint64_t>) >
174154
{
@@ -234,26 +214,6 @@ struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::fMix<T>(e
234214
}
235215
};
236216

237-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::smoothStep<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>)
238-
struct smoothStep_helper<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::smoothStep<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>) >
239-
{
240-
using return_t = T;
241-
static inline return_t __call(const T edge0, const T edge1, const T x)
242-
{
243-
return spirv::smoothStep<T>(edge0, edge1, x);
244-
}
245-
};
246-
247-
template<typename T> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::faceForward<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>)
248-
struct faceForward_helper<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::faceForward<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<T>()))>) >
249-
{
250-
using return_t = T;
251-
static inline return_t __call(const T N, const T I, const T Nref)
252-
{
253-
return spirv::faceForward<T>(N, I, Nref);
254-
}
255-
};
256-
257217
template<typename SquareMatrix> NBL_PARTIAL_REQ_TOP(matrix_traits<SquareMatrix>::Square)
258218
struct determinant_helper<SquareMatrix NBL_PARTIAL_REQ_BOT(matrix_traits<SquareMatrix>::Square) >
259219
{
@@ -263,28 +223,34 @@ struct determinant_helper<SquareMatrix NBL_PARTIAL_REQ_BOT(matrix_traits<SquareM
263223
}
264224
};
265225

266-
template<typename T, typename U> NBL_PARTIAL_REQ_TOP(always_true<decltype(spirv::refract<T, U>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<U>()))>)
267-
struct refract_helper<T, U NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::refract<T, U>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<U>()))>) >
268-
{
269-
using return_t = T;
270-
static inline return_t __call(const T I, const T N, const U eta)
271-
{
272-
return spirv::refract<T>(I, N, eta);
273-
}
274-
};
275-
276226
#else // C++ only specializations
277227

278-
template<typename T>
279-
requires concepts::Scalar<T>
280-
struct clamp_helper<T>
281-
{
282-
using return_t = T;
283-
static inline return_t __call(const T val, const T min, const T max)
284-
{
285-
return std::clamp<T>(val, min, max);
286-
}
228+
#define DECL_ARG(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) const _T arg##i
229+
#define WRAP(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) _T
230+
#define ARG(r,data,i,_T) BOOST_PP_COMMA_IF(BOOST_PP_NOT_EQUAL(i,0)) arg##i
231+
232+
// the template<> needs to be written ourselves
233+
// return type is __VA_ARGS__ to protect against `,` in templated return types
234+
#define AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(HELPER_NAME, STD_FUNCTION_NAME, REQUIREMENT, ARG_TYPE_LIST, ARG_TYPE_SET, ...)\
235+
requires REQUIREMENT \
236+
struct HELPER_NAME<BOOST_PP_SEQ_FOR_EACH_I(WRAP, _, ARG_TYPE_LIST)>\
237+
{\
238+
using return_t = __VA_ARGS__;\
239+
static inline return_t __call( BOOST_PP_SEQ_FOR_EACH_I(DECL_ARG, _, ARG_TYPE_SET) )\
240+
{\
241+
return std::STD_FUNCTION_NAME<BOOST_PP_SEQ_FOR_EACH_I(WRAP, _, ARG_TYPE_LIST)>( BOOST_PP_SEQ_FOR_EACH_I(ARG, _, ARG_TYPE_SET) );\
242+
}\
287243
};
244+
245+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(clamp_helper, clamp, concepts::Scalar<T>, (T), (T)(T)(T), T)
246+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(max_helper, max, concepts::Scalar<T>, (T), (T)(T), T)
247+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(min_helper, min, concepts::Scalar<T>, (T), (T)(T), T)
248+
249+
#undef DECL_ARG
250+
#undef WRAP
251+
#undef ARG
252+
#undef AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER
253+
288254
template<typename T>
289255
requires concepts::IntegralScalar<T>
290256
struct bitReverse_helper<T>
@@ -323,24 +289,7 @@ struct normalize_helper<Vectorial>
323289
return vec / length_helper<Vectorial>::__call(vec);
324290
}
325291
};
326-
template<typename T>
327-
requires concepts::Scalar<T>
328-
struct max_helper<T>
329-
{
330-
static T __call(NBL_CONST_REF_ARG(T) a, NBL_CONST_REF_ARG(T) b)
331-
{
332-
return std::max<T>(a, b);
333-
}
334-
};
335-
template<typename T>
336-
requires concepts::Scalar<T>
337-
struct min_helper<T>
338-
{
339-
static T __call(NBL_CONST_REF_ARG(T) a, NBL_CONST_REF_ARG(T) b)
340-
{
341-
return std::min<T>(a, b);
342-
}
343-
};
292+
344293
template<typename T>
345294
requires concepts::IntegralScalar<T>
346295
struct find_lsb_helper<T>

0 commit comments

Comments
 (0)