Skip to content

Commit 95287fd

Browse files
committed
Moved fma function
1 parent cc5644c commit 95287fd

File tree

4 files changed

+40
-39
lines changed

4 files changed

+40
-39
lines changed

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
103103
struct nMax_helper;
104104
template<typename T NBL_STRUCT_CONSTRAINABLE>
105105
struct nClamp_helper;
106-
106+
// Primary fma_helper template: declared here without a definition; the
// constrained specializations of it appear further down in this file.
template<typename T NBL_STRUCT_CONSTRAINABLE>
107+
struct fma_helper;
107108

108109
#ifdef __HLSL_VERSION // HLSL only specializations
109110

@@ -163,6 +164,7 @@ template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(refract_hel
163164
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(nMax_helper, nMax, (T), (T)(T), T)
164165
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(nMin_helper, nMin, (T), (T)(T), T)
165166
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(nClamp_helper, nClamp, (T), (T)(T), T)
167+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(fma_helper, fma, (T), (T)(T)(T), T)
166168

167169
#define BITCOUNT_HELPER_RETRUN_TYPE conditional_t<is_vector_v<T>, vector<int32_t, vector_traits<T>::Dimension>, int32_t>
168170
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitCount_helper, bitCount, (T), (T), BITCOUNT_HELPER_RETRUN_TYPE)
@@ -600,6 +602,16 @@ struct nClamp_helper<T>
600602
}
601603
};
602604

605+
// Scalar specialization of fma_helper, located in the C++-only section of the
// header. Selected for types that satisfy concepts::FloatingPointScalar.
// __call(x, y, z) forwards its three operands to std::fma and returns the
// result in the same type.
template<typename FloatingPoint>
606+
requires concepts::FloatingPointScalar<FloatingPoint>
607+
struct fma_helper<FloatingPoint>
608+
{
609+
static FloatingPoint __call(NBL_CONST_REF_ARG(FloatingPoint) x, NBL_CONST_REF_ARG(FloatingPoint) y, NBL_CONST_REF_ARG(FloatingPoint) z)
610+
{
611+
// std::fma evaluates x * y + z as one fused operation.
return std::fma(x, y, z);
612+
}
613+
};
614+
603615
#endif // C++ only specializations
604616

605617
// C++ and HLSL specializations
@@ -897,6 +909,25 @@ struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT(DOT_HELPER_REQUIREMENT) >
897909

898910
#undef DOT_HELPER_REQUIREMENT
899911

912+
// Specialization of fma_helper chosen when VECTOR_SPECIALIZATION_CONCEPT holds
// for T. __call applies the scalar fma_helper to each component of x, y and z
// and returns the per-component results packed into a value of type T.
// NOTE(review): VECTOR_SPECIALIZATION_CONCEPT is not defined in the visible
// part of this file - confirm it is in scope at this point.
template<typename T>
913+
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT)
914+
struct fma_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
915+
{
916+
using return_t = T;
917+
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
918+
{
919+
using traits = hlsl::vector_traits<T>;
920+
// Component accessors keyed on T and its scalar type, used instead of
// indexing T directly.
array_get<T, typename traits::scalar_type> getter;
921+
array_set<T, typename traits::scalar_type> setter;
922+
923+
// output is declared without an initializer; the loop below writes
// components 0 .. traits::Dimension - 1 through setter.
return_t output;
924+
for (uint32_t i = 0; i < traits::Dimension; ++i)
925+
setter(output, i, fma_helper<typename traits::scalar_type>::__call(getter(x, i), getter(y, i), getter(z, i)));
926+
927+
return output;
928+
}
929+
};
930+
900931
}
901932
}
902933
}

include/nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ inline int32_t2 unpackDouble2x32(T val)
295295
return NAMESPACE::unpackDouble2x32(val);
296296
}
297297

298+
// Public fma entry point. All three operands share the type T, which is also
// the return type. The work is delegated to
// cpp_compat_intrinsics_impl::fma_helper<T>::__call, so behaviour depends on
// which fma_helper specialization matches T.
template<typename T>
299+
inline T fma(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
300+
{
301+
return cpp_compat_intrinsics_impl::fma_helper<T>::__call(x, y, z);
302+
}
303+
298304
#undef NAMESPACE
299305

300306
}

include/nbl/builtin/hlsl/tgmath.hlsl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <nbl/builtin/hlsl/spirv_intrinsics/core.hlsl>
1414
#include <nbl/builtin/hlsl/concepts/core.hlsl>
1515
#include <nbl/builtin/hlsl/concepts/vector.hlsl>
16+
#include <nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl>
17+
1618
// C++ headers
1719
#ifndef __HLSL_VERSION
1820
#include <algorithm>
@@ -211,12 +213,6 @@ inline T ceil(NBL_CONST_REF_ARG(T) val)
211213
return tgmath_impl::ceil_helper<T>::__call(val);
212214
}
213215

214-
// fma entry point as it stood in tgmath.hlsl: takes three operands of type T
// and returns T by delegating to tgmath_impl::fma_helper<T>::__call.
template<typename T>
215-
inline T fma(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
216-
{
217-
return tgmath_impl::fma_helper<T>::__call(x, y, z);
218-
}
219-
220216
template<typename T, typename U>
221217
inline T ldexp(NBL_CONST_REF_ARG(T) arg, NBL_CONST_REF_ARG(U) exp)
222218
{

include/nbl/builtin/hlsl/tgmath/impl.hlsl

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
8383
struct trunc_helper;
8484
template<typename T NBL_STRUCT_CONSTRAINABLE>
8585
struct ceil_helper;
86-
// Primary fma_helper template: declaration only, with the constrained
// specializations given further down in this file.
template<typename T NBL_STRUCT_CONSTRAINABLE>
87-
struct fma_helper;
8886
template<typename T, typename U NBL_STRUCT_CONSTRAINABLE>
8987
struct ldexp_helper;
9088
template<typename T NBL_STRUCT_CONSTRAINABLE>
@@ -138,7 +136,6 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(roundEven_helper, round
138136
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(trunc_helper, trunc, (T), (T), T)
139137
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(ceil_helper, ceil, (T), (T), T)
140138
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(pow_helper, pow, (T), (T)(T), T)
141-
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(fma_helper, fma, (T), (T)(T)(T), T)
142139
template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(ldexp_helper, ldexp, (T)(U), (T)(U), T)
143140
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(modfStruct_helper, modfStruct, (T), (T), ModfOutput<T>)
144141
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(frexpStruct_helper, frexpStruct, (T), (T), FrexpOutput<T>)
@@ -337,16 +334,6 @@ struct roundEven_helper<FloatingPoint NBL_PARTIAL_REQ_BOT(concepts::FloatingPoin
337334
}
338335
};
339336

340-
// Scalar specialization of fma_helper, constrained through the
// NBL_PARTIAL_REQ_TOP / NBL_PARTIAL_REQ_BOT macro pair on
// concepts::FloatingPointScalar<FloatingPoint>. __call(x, y, z) forwards the
// three operands to std::fma and returns the result in the same type.
template<typename FloatingPoint>
341-
NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<FloatingPoint>)
342-
struct fma_helper<FloatingPoint NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<FloatingPoint>) >
343-
{
344-
static FloatingPoint __call(NBL_CONST_REF_ARG(FloatingPoint) x, NBL_CONST_REF_ARG(FloatingPoint) y, NBL_CONST_REF_ARG(FloatingPoint) z)
345-
{
346-
// std::fma evaluates x * y + z as one fused operation.
return std::fma(x, y, z);
347-
}
348-
};
349-
350337
template<typename T, typename U>
351338
NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<T> && concepts::IntegralScalar<U>)
352339
struct ldexp_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T> && concepts::IntegralScalar<U>) >
@@ -510,25 +497,6 @@ struct pow_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
510497
}
511498
};
512499

513-
// Specialization of fma_helper chosen when VECTOR_SPECIALIZATION_CONCEPT holds
// for T. __call applies the scalar fma_helper to each component of x, y and z
// and returns the per-component results packed into a value of type T.
template<typename T>
514-
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT)
515-
struct fma_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
516-
{
517-
using return_t = T;
518-
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
519-
{
520-
using traits = hlsl::vector_traits<T>;
521-
// Component accessors keyed on T and its scalar type, used instead of
// indexing T directly.
array_get<T, typename traits::scalar_type> getter;
522-
array_set<T, typename traits::scalar_type> setter;
523-
524-
// output is declared without an initializer; the loop below writes
// components 0 .. traits::Dimension - 1 through setter.
return_t output;
525-
for (uint32_t i = 0; i < traits::Dimension; ++i)
526-
setter(output, i, fma_helper<typename traits::scalar_type>::__call(getter(x, i), getter(y, i), getter(z, i)));
527-
528-
return output;
529-
}
530-
};
531-
532500
template<typename T, typename U>
533501
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT && (vector_traits<T>::Dimension == vector_traits<U>::Dimension))
534502
struct ldexp_helper<T, U NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT && (vector_traits<T>::Dimension == vector_traits<U>::Dimension)) >

0 commit comments

Comments
 (0)