Skip to content

Commit b969205

Browse files
[NFCI] Refactor detail::Convert[To|From]OpenCLType (#17024)
Same result but avoids using `vec::vector_t` that is going to be removed.
1 parent 2a9ca31 commit b969205

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,22 @@ template <typename T> auto convertToOpenCLType(T &&x) {
137137
return reinterpret_cast<result_type>(x);
138138
} else if constexpr (is_vec_v<no_ref>) {
139139
using ElemTy = typename no_ref::element_type;
140-
// sycl::half may convert to _Float16, and we would try to instantiate
141-
// vec class with _Float16 DataType, which is not expected there. As
142-
// such, leave vector<half, N> as-is.
143-
using MatchingVec =
144-
vec<std::conditional_t<std::is_same_v<ElemTy, half>, ElemTy,
145-
decltype(convertToOpenCLType(
146-
std::declval<ElemTy>()))>,
147-
no_ref::size()>;
140+
using ConvertedElemTy =
141+
decltype(convertToOpenCLType(std::declval<ElemTy>()));
142+
static constexpr int NumElements = no_ref::size();
148143
#ifdef __SYCL_DEVICE_ONLY__
149-
return sycl::bit_cast<typename MatchingVec::vector_t>(x);
144+
using vector_t =
145+
std::conditional_t<NumElements == 1, ConvertedElemTy,
146+
ConvertedElemTy
147+
__attribute__((ext_vector_type(NumElements)))>;
148+
return sycl::bit_cast<vector_t>(x);
150149
#else
150+
// sycl::half may convert to _Float16, and we would try to
151+
// instantiate vec class with _Float16 DataType, which is not
152+
// expected there. As such, leave vector<half, N> as-is.
153+
using MatchingVec = vec<std::conditional_t<std::is_same_v<ElemTy, half>,
154+
ElemTy, ConvertedElemTy>,
155+
NumElements>;
151156
return x.template as<MatchingVec>();
152157
#endif
153158
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
@@ -203,7 +208,7 @@ template <typename To, typename From> auto convertFromOpenCLTypeFor(From &&x) {
203208
if constexpr (is_vec_v<To_noref> && is_vec_v<From_noref>)
204209
return x.template as<To_noref>();
205210
else if constexpr (is_vec_v<To_noref> && is_ext_vector_v<From_noref>)
206-
return To_noref{bit_cast<typename To_noref::vector_t>(x)};
211+
return bit_cast<To>(x);
207212
else
208213
return static_cast<To>(x);
209214
}

0 commit comments

Comments
 (0)