Skip to content

Commit f6a69fe

Browse files
Use FMA for unspecialized dot product in intrinsics_impl.hlsl
1 parent 95287fd commit f6a69fe

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,25 @@ struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::B
882882
}
883883
};
884884

885+
template<typename T>
886+
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT)
887+
struct fma_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
888+
{
889+
using return_t = T;
890+
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
891+
{
892+
using traits = hlsl::vector_traits<T>;
893+
array_get<T, typename traits::scalar_type> getter;
894+
array_set<T, typename traits::scalar_type> setter;
895+
896+
return_t output;
897+
for (uint32_t i = 0; i < traits::Dimension; ++i)
898+
setter(output, i, fma_helper<typename traits::scalar_type>::__call(getter(x, i), getter(y, i), getter(z, i)));
899+
900+
return output;
901+
}
902+
};
903+
885904
#ifdef __HLSL_VERSION
886905
#define DOT_HELPER_REQUIREMENT (concepts::Vectorial<Vectorial> && !is_vector_v<Vectorial>)
887906
#else
@@ -901,35 +920,16 @@ struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT(DOT_HELPER_REQUIREMENT) >
901920

902921
scalar_type retval = getter(lhs, 0) * getter(rhs, 0);
903922
for (uint32_t i = 1; i < ArrayDim; ++i)
904-
retval = retval + getter(lhs, i) * getter(rhs, i);
923+
retval = fma_helper<scalar_type>::__call(getter(lhs, i), getter(rhs, i), retval);
905924

906925
return retval;
907926
}
908927
};
909928

910929
#undef DOT_HELPER_REQUIREMENT
911930

912-
template<typename T>
913-
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT)
914-
struct fma_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
915-
{
916-
using return_t = T;
917-
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
918-
{
919-
using traits = hlsl::vector_traits<T>;
920-
array_get<T, typename traits::scalar_type> getter;
921-
array_set<T, typename traits::scalar_type> setter;
922-
923-
return_t output;
924-
for (uint32_t i = 0; i < traits::Dimension; ++i)
925-
setter(output, i, fma_helper<typename traits::scalar_type>::__call(getter(x, i), getter(y, i), getter(z, i)));
926-
927-
return output;
928-
}
929-
};
930-
931931
}
932932
}
933933
}
934934

935-
#endif
935+
#endif

0 commit comments

Comments
 (0)