Skip to content

Commit c72a85d

Browse files
authored
[SYCL][ESIMD] Fix compilation issue of abs function (#6443)
* Fix compilation issue of abs function
1 parent 963e64d commit c72a85d

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

sycl/include/sycl/ext/intel/esimd/math.hpp

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -93,56 +93,53 @@ saturate(simd<T1, SZ> src) {
9393
// abs
9494
namespace detail {
9595

96-
template <typename T0, typename T1, int SZ>
97-
ESIMD_NODEBUG ESIMD_INLINE simd<T0, SZ>
98-
__esimd_abs_common_internal(simd<T1, SZ> src0) {
99-
simd<T1, SZ> Result = simd<T0, SZ>(__esimd_abs<T1, SZ>(src0.data()));
100-
return Result;
96+
template <typename TRes, typename TArg, int SZ>
97+
ESIMD_NODEBUG ESIMD_INLINE simd<TRes, SZ>
98+
__esimd_abs_common_internal(simd<TArg, SZ> src0) {
99+
simd<TArg, SZ> Result = simd<TArg, SZ>(__esimd_abs<TArg, SZ>(src0.data()));
100+
return convert<TRes>(Result);
101101
}
102102

103-
template <typename T0, typename T1>
103+
template <typename TRes, typename TArg>
104104
ESIMD_NODEBUG
105-
ESIMD_INLINE std::enable_if_t<detail::is_esimd_scalar<T0>::value &&
106-
detail::is_esimd_scalar<T1>::value,
107-
std::remove_const_t<T0>>
108-
__esimd_abs_common_internal(T1 src0) {
109-
using TT0 = std::remove_const_t<T0>;
110-
using TT1 = std::remove_const_t<T1>;
111-
112-
simd<TT1, 1> Src0 = src0;
113-
simd<TT0, 1> Result = __esimd_abs_common_internal<TT0>(Src0);
114-
return Result[0];
105+
ESIMD_INLINE std::enable_if_t<detail::is_esimd_scalar<TRes>::value &&
106+
detail::is_esimd_scalar<TArg>::value,
107+
TRes>
108+
__esimd_abs_common_internal(TArg src0) {
109+
simd<TArg, 1> Src0 = src0;
110+
simd<TArg, 1> Result = __esimd_abs_common_internal<TArg>(Src0);
111+
return convert<TRes>(Result)[0];
115112
}
116113
} // namespace detail
117114
/// @endcond ESIMD_DETAIL
118115

119116
/// Get absolute value (vector version)
120-
/// @tparam T0 element type of the returned vector.
121-
/// @tparam T1 element type of the input vector.
117+
/// @tparam TRes element type of the returned vector.
118+
/// @tparam TArg element type of the input vector.
122119
/// @tparam SZ size of the input and returned vector.
123120
/// @param src0 the input vector.
124121
/// @return vector of absolute values.
125-
template <typename T0, typename T1, int SZ>
122+
template <typename TRes, typename TArg, int SZ>
126123
__ESIMD_API std::enable_if_t<
127-
!std::is_same<std::remove_const_t<T0>, std::remove_const_t<T1>>::value,
128-
simd<T0, SZ>>
129-
abs(simd<T1, SZ> src0) {
130-
return detail::__esimd_abs_common_internal<T0, T1, SZ>(src0.data());
124+
!std::is_same<std::remove_const_t<TRes>, std::remove_const_t<TArg>>::value,
125+
simd<TRes, SZ>>
126+
abs(simd<TArg, SZ> src0) {
127+
return detail::__esimd_abs_common_internal<TRes, TArg, SZ>(src0.data());
131128
}
132129

133130
/// Get absolute value (scalar version)
134131
/// @tparam T0 element type of the returned value.
135132
/// @tparam T1 element type of the input value.
136133
/// @param src0 the source operand.
137134
/// @return absolute value.
138-
template <typename T0, typename T1>
139-
__ESIMD_API std::enable_if_t<
140-
!std::is_same<std::remove_const_t<T0>, std::remove_const_t<T1>>::value &&
141-
detail::is_esimd_scalar<T0>::value &&
142-
detail::is_esimd_scalar<T1>::value,
143-
std::remove_const_t<T0>>
144-
abs(T1 src0) {
145-
return detail::__esimd_abs_common_internal<T0, T1>(src0);
135+
template <typename TRes, typename TArg>
136+
__ESIMD_API std::enable_if_t<!std::is_same<std::remove_const_t<TRes>,
137+
std::remove_const_t<TArg>>::value &&
138+
detail::is_esimd_scalar<TRes>::value &&
139+
detail::is_esimd_scalar<TArg>::value,
140+
std::remove_const_t<TRes>>
141+
abs(TArg src0) {
142+
return detail::__esimd_abs_common_internal<TRes, TArg>(src0);
146143
}
147144

148145
/// Get absolute value (vector version). This is a specialization of a version

sycl/include/sycl/ext/intel/esimd/simd.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ class simd : public detail::simd_obj_impl<
181181
/// element type \c To.
182182
template <typename To, typename From, int N>
183183
ESIMD_INLINE simd<To, N> convert(const simd<From, N> &val) {
184-
if constexpr (std::is_same_v<To, From>)
184+
if constexpr (std::is_same_v<std::remove_const_t<To>,
185+
std::remove_const_t<From>>)
185186
return val;
186187
else
187188
return detail::convert_vector<To, From, N>(val.data());

0 commit comments

Comments
 (0)