Skip to content

Commit 436191a

Browse files
[SYCL] Fix 1-element vec ambiguities (#17722)
Implements KhronosGroup/SYCL-Docs#670. Technically, this also implements part of KhronosGroup/SYCL-Docs#674 (`std::byte` as element type), because there is no reasonable way to make the two changes completely independent. This is built on top of #17712 and #17713.
1 parent 9e38e3a commit 436191a

File tree

6 files changed

+139
-20
lines changed

6 files changed

+139
-20
lines changed

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ struct IncDec {};
6666

6767
template <class T> static constexpr bool not_fp = !is_vgenfloat_v<T>;
6868

69+
#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
70+
// Not using `is_byte_v` to avoid unnecessary dependencies on `half`/`bfloat16`
71+
// headers.
72+
template <class T>
73+
static constexpr bool not_byte =
74+
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
75+
!std::is_same_v<T, std::byte>;
76+
#else
77+
true;
78+
#endif
79+
#endif
80+
6981
// To provide information about operators availability depending on vec/swizzle
7082
// element type.
7183
template <typename Op, typename T>
@@ -80,6 +92,7 @@ inline constexpr bool is_op_available_for_type<OpAssign<Op>, T> =
8092
inline constexpr bool is_op_available_for_type<OP, T> = COND;
8193

8294
// clang-format off
95+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
8396
__SYCL_OP_AVAILABILITY(std::plus<void> , true)
8497
__SYCL_OP_AVAILABILITY(std::minus<void> , true)
8598
__SYCL_OP_AVAILABILITY(std::multiplies<void> , true)
@@ -110,6 +123,38 @@ __SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
110123
__SYCL_OP_AVAILABILITY(UnaryPlus , true)
111124

112125
__SYCL_OP_AVAILABILITY(IncDec , true)
126+
#else
127+
__SYCL_OP_AVAILABILITY(std::plus<void> , not_byte<T>)
128+
__SYCL_OP_AVAILABILITY(std::minus<void> , not_byte<T>)
129+
__SYCL_OP_AVAILABILITY(std::multiplies<void> , not_byte<T>)
130+
__SYCL_OP_AVAILABILITY(std::divides<void> , not_byte<T>)
131+
__SYCL_OP_AVAILABILITY(std::modulus<void> , not_fp<T>)
132+
133+
__SYCL_OP_AVAILABILITY(std::bit_and<void> , not_fp<T>)
134+
__SYCL_OP_AVAILABILITY(std::bit_or<void> , not_fp<T>)
135+
__SYCL_OP_AVAILABILITY(std::bit_xor<void> , not_fp<T>)
136+
137+
__SYCL_OP_AVAILABILITY(std::equal_to<void> , true)
138+
__SYCL_OP_AVAILABILITY(std::not_equal_to<void> , true)
139+
__SYCL_OP_AVAILABILITY(std::less<void> , true)
140+
__SYCL_OP_AVAILABILITY(std::greater<void> , true)
141+
__SYCL_OP_AVAILABILITY(std::less_equal<void> , true)
142+
__SYCL_OP_AVAILABILITY(std::greater_equal<void> , true)
143+
144+
__SYCL_OP_AVAILABILITY(std::logical_and<void> , not_byte<T> && not_fp<T>)
145+
__SYCL_OP_AVAILABILITY(std::logical_or<void> , not_byte<T> && not_fp<T>)
146+
147+
__SYCL_OP_AVAILABILITY(ShiftLeft , not_byte<T> && not_fp<T>)
148+
__SYCL_OP_AVAILABILITY(ShiftRight , not_byte<T> && not_fp<T>)
149+
150+
// Unary
151+
__SYCL_OP_AVAILABILITY(std::negate<void> , not_byte<T>)
152+
__SYCL_OP_AVAILABILITY(std::logical_not<void> , not_byte<T>)
153+
__SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
154+
__SYCL_OP_AVAILABILITY(UnaryPlus , not_byte<T>)
155+
156+
__SYCL_OP_AVAILABILITY(IncDec , not_byte<T>)
157+
#endif
113158
// clang-format on
114159

115160
#undef __SYCL_OP_AVAILABILITY
@@ -188,6 +233,12 @@ template <typename Self> struct VecOperators {
188233
using element_type = typename from_incomplete<Self>::element_type;
189234
static constexpr int N = from_incomplete<Self>::size();
190235

236+
#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
237+
template <typename T>
238+
static constexpr bool is_compatible_scalar =
239+
std::is_convertible_v<T, typename from_incomplete<Self>::element_type>;
240+
#endif
241+
191242
template <typename Op>
192243
using result_t = std::conditional_t<
193244
is_logical<Op>, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
@@ -293,6 +344,7 @@ template <typename Self> struct VecOperators {
293344
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, IncDec>>>
294345
: public IncDecImpl<Self> {};
295346

347+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
296348
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
297349
template <typename Op> \
298350
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
@@ -341,13 +393,60 @@ template <typename Self> struct VecOperators {
341393
friend auto operator OPERATOR(const Self &v) { return apply<OP>(v); } \
342394
};
343395

396+
#else
397+
398+
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
399+
template <typename Op> \
400+
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
401+
friend result_t<OP> operator OPERATOR(const Self & lhs, \
402+
const Self & rhs) { \
403+
return VecOperators::apply<OP>(lhs, rhs); \
404+
} \
405+
template <typename T> \
406+
friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
407+
operator OPERATOR(const Self & lhs, const T & rhs) { \
408+
return VecOperators::apply<OP>(lhs, Self{static_cast<T>(rhs)}); \
409+
} \
410+
template <typename T> \
411+
friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
412+
operator OPERATOR(const T & lhs, const Self & rhs) { \
413+
return VecOperators::apply<OP>(Self{static_cast<T>(lhs)}, rhs); \
414+
} \
415+
};
416+
417+
#define __SYCL_VEC_OPASSIGN_MIXIN(OP, OPERATOR) \
418+
template <typename Op> \
419+
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OpAssign<OP>>>> { \
420+
friend Self &operator OPERATOR(Self & lhs, const Self & rhs) { \
421+
lhs = OP{}(lhs, rhs); \
422+
return lhs; \
423+
} \
424+
template <typename T> \
425+
friend std::enable_if_t<is_compatible_scalar<T>, Self &> \
426+
operator OPERATOR(Self & lhs, const T & rhs) { \
427+
lhs = OP{}(lhs, rhs); \
428+
return lhs; \
429+
} \
430+
};
431+
432+
#define __SYCL_VEC_UOP_MIXIN(OP, OPERATOR) \
433+
template <typename Op> \
434+
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
435+
friend result_t<OP> operator OPERATOR(const Self & v) { \
436+
return apply<OP>(v); \
437+
} \
438+
};
439+
440+
#endif
441+
344442
__SYCL_INSTANTIATE_OPERATORS(__SYCL_VEC_BINOP_MIXIN,
345443
__SYCL_VEC_OPASSIGN_MIXIN, __SYCL_VEC_UOP_MIXIN)
346444

347445
#undef __SYCL_VEC_UOP_MIXIN
348446
#undef __SYCL_VEC_OPASSIGN_MIXIN
349447
#undef __SYCL_VEC_BINOP_MIXIN
350448

449+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
351450
template <typename Op>
352451
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, std::bit_not<void>>>> {
353452
template <typename T = typename from_incomplete<Self>::element_type>
@@ -356,6 +455,7 @@ template <typename Self> struct VecOperators {
356455
return apply<std::bit_not<void>>(v);
357456
}
358457
};
458+
#endif
359459

360460
template <typename... Op>
361461
struct __SYCL_EBO CombineImpl : public OpMixin<Op>... {};
@@ -377,6 +477,7 @@ template <typename Self> struct VecOperators {
377477
OpAssign<ShiftRight>, IncDec> {};
378478
};
379479

480+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
380481
template <typename DataT, int NumElements>
381482
class vec_arith : public VecOperators<vec<DataT, NumElements>>::Combined {};
382483

@@ -427,6 +528,7 @@ class vec_arith<std::byte, NumElements>
427528
}
428529
};
429530
#endif // (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
531+
#endif
430532

431533
#undef __SYCL_INSTANTIATE_OPERATORS
432534

sycl/include/sycl/vector.hpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,18 @@ template <typename DataT> class vec_base<DataT, 1> {
318318
// Provides a cross-platform vector class template that works efficiently on
319319
// SYCL devices as well as in host C++ code.
320320
template <typename DataT, int NumElements>
321-
class __SYCL_EBO vec
322-
: public detail::vec_arith<DataT, NumElements>,
323-
public detail::ApplyIf<
324-
NumElements == 1,
325-
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
326-
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
327-
// Keep it last to simplify ABI layout test:
328-
public detail::vec_base<DataT, NumElements> {
321+
class __SYCL_EBO vec :
322+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
323+
public detail::vec_arith<DataT, NumElements>,
324+
#else
325+
public detail::VecOperators<vec<DataT, NumElements>>::Combined,
326+
#endif
327+
public detail::ApplyIf<
328+
NumElements == 1,
329+
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
330+
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
331+
// Keep it last to simplify ABI layout test:
332+
public detail::vec_base<DataT, NumElements> {
329333
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
330334
"DataT must be cv-unqualified");
331335

@@ -408,6 +412,7 @@ class __SYCL_EBO vec
408412
constexpr vec &operator=(const vec &) = default;
409413
constexpr vec &operator=(vec &&) = default;
410414

415+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
411416
// Template required to prevent ambiguous overload with the copy assignment
412417
// when NumElements == 1. The template prevents implicit conversion from
413418
// vec<_, 1> to DataT.
@@ -427,6 +432,14 @@ class __SYCL_EBO vec
427432
*this = Rhs.template as<vec>();
428433
return *this;
429434
}
435+
#else
436+
template <typename T>
437+
typename std::enable_if_t<std::is_convertible_v<T, DataT>, vec &>
438+
operator=(const T &Rhs) {
439+
*this = vec{static_cast<DataT>(Rhs)};
440+
return *this;
441+
}
442+
#endif
430443

431444
__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
432445
static constexpr size_t get_count() { return size(); }
@@ -536,8 +549,10 @@ class __SYCL_EBO vec
536549
int... T5>
537550
friend class detail::SwizzleOp;
538551
template <typename T1, int T2> friend class __SYCL_EBO vec;
552+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
539553
// To allow arithmetic operators access private members of vec.
540554
template <typename T1, int T2> friend class detail::vec_arith;
555+
#endif
541556
};
542557
///////////////////////// class sycl::vec /////////////////////////
543558

sycl/test-e2e/Basic/vector/byte.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ int main() {
180180
assert(SwizByte2Neg[0] == ~SwizByte2B[0]);
181181
}
182182

183+
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
183184
{
184185
// std::byte is not an arithmetic type and it only supports the following
185186
// overloads of >> and << operators.
@@ -207,6 +208,7 @@ int main() {
207208
assert(SwizShiftRight[0] == SwizByte2Shift[0] >> 3 &&
208209
SwizShiftLeft[1] == SwizByte2Shift[1] << 3);
209210
}
211+
#endif
210212
}
211213

212214
return 0;

sycl/test-e2e/Basic/vector/vec_binary_scalar_order.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ bool CheckResult(sycl::vec<T1, N> V, T2 Ref) {
3838
constexpr T RefVal = 2; \
3939
VecT InVec{static_cast<T>(RefVal)}; \
4040
{ \
41-
VecT OutVecsDevice[2]; \
41+
ResT OutVecsDevice[2]; \
4242
T OutRefsDevice[2]; \
4343
{ \
44-
sycl::buffer<VecT, 1> OutVecsBuff{OutVecsDevice, 2}; \
44+
sycl::buffer<ResT, 1> OutVecsBuff{OutVecsDevice, 2}; \
4545
sycl::buffer<T, 1> OutRefsBuff{OutRefsDevice, 2}; \
4646
Q.submit([&](sycl::handler &CGH) { \
4747
sycl::accessor OutVecsAcc{OutVecsBuff, CGH, sycl::read_write}; \

sycl/test-e2e/DeviceLib/built-ins/vector_integer.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ int main() {
203203

204204
// abs
205205
{
206-
s::uint2 r{0};
206+
s::int2 r{0};
207207
{
208-
s::buffer<s::uint2, 1> BufR(&r, s::range<1>(1));
208+
s::buffer<s::int2, 1> BufR(&r, s::range<1>(1));
209209
s::queue myQueue;
210210
myQueue.submit([&](s::handler &cgh) {
211211
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
@@ -214,8 +214,8 @@ int main() {
214214
});
215215
});
216216
}
217-
unsigned int r1 = r.x();
218-
unsigned int r2 = r.y();
217+
int r1 = r.x();
218+
int r2 = r.y();
219219
assert(r1 == 5);
220220
assert(r2 == 2);
221221
}
@@ -240,9 +240,9 @@ int main() {
240240

241241
// abs_diff
242242
{
243-
s::uint2 r{0};
243+
s::int2 r{0};
244244
{
245-
s::buffer<s::uint2, 1> BufR(&r, s::range<1>(1));
245+
s::buffer<s::int2, 1> BufR(&r, s::range<1>(1));
246246
s::queue myQueue;
247247
myQueue.submit([&](s::handler &cgh) {
248248
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
@@ -251,8 +251,8 @@ int main() {
251251
});
252252
});
253253
}
254-
unsigned int r1 = r.x();
255-
unsigned int r2 = r.y();
254+
int r1 = r.x();
255+
int r2 = r.y();
256256
assert(r1 == 4);
257257
assert(r2 == 1);
258258
}

sycl/test/basic_tests/vectors/assign.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ using sw_double_2 = decltype(std::declval<vec<double, 4>>().swizzle<1, 2>());
2727
// EXCEPT_IN_PREVIEW condition<>
2828

2929
static_assert( std::is_assignable_v<vec<half, 1>, half>);
30-
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, float>);
31-
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, double>);
30+
static_assert( std::is_assignable_v<vec<half, 1>, float>);
31+
static_assert( std::is_assignable_v<vec<half, 1>, double>);
3232
static_assert( std::is_assignable_v<vec<half, 1>, vec<half, 1>>);
3333
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, vec<float, 1>>);
3434
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, vec<double, 1>>);

0 commit comments

Comments
 (0)