@@ -137,17 +137,22 @@ template <typename T> auto convertToOpenCLType(T &&x) {
137
137
return reinterpret_cast <result_type>(x);
138
138
} else if constexpr (is_vec_v<no_ref>) {
139
139
using ElemTy = typename no_ref::element_type;
140
- // sycl::half may convert to _Float16, and we would try to instantiate
141
- // vec class with _Float16 DataType, which is not expected there. As
142
- // such, leave vector<half, N> as-is.
143
- using MatchingVec =
144
- vec<std::conditional_t <std::is_same_v<ElemTy, half>, ElemTy,
145
- decltype (convertToOpenCLType (
146
- std::declval<ElemTy>()))>,
147
- no_ref::size ()>;
140
+ using ConvertedElemTy =
141
+ decltype (convertToOpenCLType (std::declval<ElemTy>()));
142
+ static constexpr int NumElements = no_ref::size ();
148
143
#ifdef __SYCL_DEVICE_ONLY__
149
- return sycl::bit_cast<typename MatchingVec::vector_t >(x);
144
+ using vector_t =
145
+ std::conditional_t <NumElements == 1 , ConvertedElemTy,
146
+ ConvertedElemTy
147
+ __attribute__ ((ext_vector_type (NumElements)))>;
148
+ return sycl::bit_cast<vector_t >(x);
150
149
#else
150
+ // sycl::half may convert to _Float16, and we would try to
151
+ // instantiate vec class with _Float16 DataType, which is not
152
+ // expected there. As such, leave vector<half, N> as-is.
153
+ using MatchingVec = vec<std::conditional_t <std::is_same_v<ElemTy, half>,
154
+ ElemTy, ConvertedElemTy>,
155
+ NumElements>;
151
156
return x.template as <MatchingVec>();
152
157
#endif
153
158
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
@@ -203,7 +208,7 @@ template <typename To, typename From> auto convertFromOpenCLTypeFor(From &&x) {
203
208
if constexpr (is_vec_v<To_noref> && is_vec_v<From_noref>)
204
209
return x.template as <To_noref>();
205
210
else if constexpr (is_vec_v<To_noref> && is_ext_vector_v<From_noref>)
206
- return To_noref{ bit_cast<typename To_noref:: vector_t >(x)} ;
211
+ return bit_cast<To >(x);
207
212
else
208
213
return static_cast <To>(x);
209
214
}
0 commit comments