Skip to content

Commit cc5644c

Browse files
committed
Added dot HLSL specialization
1 parent be5ef6b commit cc5644c

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -600,25 +600,6 @@ struct nClamp_helper<T>
600600
}
601601
};
602602

603-
template<typename Vectorial>
604-
NBL_PARTIAL_REQ_TOP(concepts::Vectorial<Vectorial>)
605-
struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT(concepts::Vectorial<Vectorial>) >
606-
{
607-
using scalar_type = typename vector_traits<Vectorial>::scalar_type;
608-
609-
static inline scalar_type __call(NBL_CONST_REF_ARG(Vectorial) lhs, NBL_CONST_REF_ARG(Vectorial) rhs)
610-
{
611-
static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
612-
static array_get<Vectorial, scalar_type> getter;
613-
614-
scalar_type retval = getter(lhs, 0) * getter(rhs, 0);
615-
for (uint32_t i = 1; i < ArrayDim; ++i)
616-
retval = retval + getter(lhs, i) * getter(rhs, i);
617-
618-
return retval;
619-
}
620-
};
621-
622603
#endif // C++ only specializations
623604

624605
// C++ and HLSL specializations
@@ -889,6 +870,33 @@ struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::B
889870
}
890871
};
891872

873+
#ifdef __HLSL_VERSION
874+
#define DOT_HELPER_REQUIREMENT (concepts::Vectorial<Vectorial> && !is_vector_v<Vectorial>)
875+
#else
876+
#define DOT_HELPER_REQUIREMENT concepts::Vectorial<Vectorial>
877+
#endif
878+
879+
template<typename Vectorial>
880+
NBL_PARTIAL_REQ_TOP(DOT_HELPER_REQUIREMENT)
881+
struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT(DOT_HELPER_REQUIREMENT) >
882+
{
883+
using scalar_type = typename vector_traits<Vectorial>::scalar_type;
884+
885+
static inline scalar_type __call(NBL_CONST_REF_ARG(Vectorial) lhs, NBL_CONST_REF_ARG(Vectorial) rhs)
886+
{
887+
static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
888+
static array_get<Vectorial, scalar_type> getter;
889+
890+
scalar_type retval = getter(lhs, 0) * getter(rhs, 0);
891+
for (uint32_t i = 1; i < ArrayDim; ++i)
892+
retval = retval + getter(lhs, i) * getter(rhs, i);
893+
894+
return retval;
895+
}
896+
};
897+
898+
#undef DOT_HELPER_REQUIREMENT
899+
892900
}
893901
}
894902
}

include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ template<typename T NBL_FUNC_REQUIRES(is_floating_point_v<T> && is_vector_v<T>)
318318
[[vk::ext_instruction(spv::OpIsInf)]]
319319
vector<bool, vector_traits<T>::Dimension> isInf(T val);
320320

321-
template<typename Vector NBL_FUNC_REQUIRES(is_vector_v<T>)
321+
template<typename Vector NBL_FUNC_REQUIRES(is_vector_v<Vector>)
322322
[[vk::ext_instruction( spv::OpDot )]]
323323
typename vector_traits<Vector>::scalar_type dot(Vector lhs, Vector rhs);
324324

0 commit comments

Comments
 (0)