Skip to content

Commit 462d6d4

Browse files
[SYCL] Fix operators for bool swizzled vec (#12001)
c3a9615 added `CommonDataT` to fix an issue where results were truncated if an operand had a larger type than the result. As `CommonDataT` uses `std::common_type_t` of the left and right operands, it ignores `DataT`. This causes an issue for cases where `CommonDataT = bool` and `DataT != bool` since the operation result will be implicitly converted from `DataT` to `bool`. This change adds `DataT` to set of types used to define `CommonDataT`. Cases where the operands were `bool` type will now have the correct result rather than `1`. Fixes #11995. --------- Signed-off-by: Michael Aziz <michael.aziz@intel.com> Co-authored-by: aelovikov-intel <andrei.elovikov@intel.com>
1 parent ef6f8f4 commit 462d6d4

File tree

2 files changed

+99
-2
lines changed

2 files changed

+99
-2
lines changed

sycl/include/sycl/types.hpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,8 +1544,37 @@ template <typename VecT, typename OperationLeftT, typename OperationRightT,
15441544
template <typename> class OperationCurrentT, int... Indexes>
15451545
class SwizzleOp {
15461546
using DataT = typename VecT::element_type;
1547-
using CommonDataT = std::common_type_t<typename OperationLeftT::DataT,
1548-
typename OperationRightT::DataT>;
1547+
// Certain operators return a vector with a different element type. Also, the
1548+
// left and right operand types may differ. CommonDataT selects a result type
1549+
// based on these types to ensure that the result value can be represented.
1550+
//
1551+
// Example 1:
1552+
// sycl::vec<unsigned char, 4> vec{...};
1553+
// auto result = 300u + vec.x();
1554+
//
1555+
// CommonDataT is std::common_type_t<OperationLeftT, OperationRightT> since
1556+
// it's larger than unsigned char.
1557+
//
1558+
// Example 2:
1559+
// sycl::vec<bool, 1> vec{...};
1560+
// auto result = vec.template swizzle<sycl::elem::s0>() && vec;
1561+
//
1562+
// CommonDataT is DataT since operator&& returns a vector with element type
1563+
// int8_t, which is larger than bool.
1564+
//
1565+
// Example 3:
1566+
// sycl::vec<std::byte, 4> vec{...}; auto swlo = vec.lo();
1567+
// auto result = swlo == swlo;
1568+
//
1569+
// CommonDataT is DataT since operator== returns a vector with element type
1570+
// int8_t, which is the same size as std::byte. std::common_type_t<DataT, ...>
1571+
// can't be used here since there's no type that int8_t and std::byte can both
1572+
// be implicitly converted to.
1573+
using OpLeftDataT = typename OperationLeftT::DataT;
1574+
using OpRightDataT = typename OperationRightT::DataT;
1575+
using CommonDataT = std::conditional_t<
1576+
sizeof(DataT) >= sizeof(std::common_type_t<OpLeftDataT, OpRightDataT>),
1577+
DataT, std::common_type_t<OpLeftDataT, OpRightDataT>>;
15491578
static constexpr int getNumElements() { return sizeof...(Indexes); }
15501579

15511580
using rel_t = detail::rel_t<DataT>;
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// RUN: %if preview-breaking-changes-supported %{ %clangxx -fsycl -fpreview-breaking-changes %s -o %t2.out %}
5+
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}
6+
7+
#include <cstdlib>
8+
#include <sycl/sycl.hpp>
9+
10+
template <typename T, typename ResultT>
11+
bool testAndOperator(const std::string &typeName) {
12+
constexpr int N = 5;
13+
std::array<ResultT, N> results{};
14+
15+
sycl::queue q;
16+
sycl::buffer<ResultT, 1> buffer{results.data(), N};
17+
q.submit([&](sycl::handler &cgh) {
18+
sycl::accessor acc{buffer, cgh, sycl::write_only};
19+
cgh.parallel_for(sycl::range<1>{1}, [=](sycl::id<1> id) {
20+
auto testVec1 = sycl::vec<T, 1>(static_cast<T>(1));
21+
auto testVec2 = sycl::vec<T, 1>(static_cast<T>(2));
22+
sycl::vec<ResultT, 1> resVec;
23+
24+
ResultT expected = static_cast<ResultT>(
25+
-(static_cast<ResultT>(1) && static_cast<ResultT>(2)));
26+
acc[0] = expected;
27+
28+
// LHS swizzle
29+
resVec = testVec1.template swizzle<sycl::elem::s0>() && testVec2;
30+
acc[1] = resVec[0];
31+
32+
// RHS swizzle
33+
resVec = testVec1 && testVec2.template swizzle<sycl::elem::s0>();
34+
acc[2] = resVec[0];
35+
36+
// No swizzle
37+
resVec = testVec1 && testVec2;
38+
acc[3] = resVec[0];
39+
40+
// Both swizzle
41+
resVec = testVec1.template swizzle<sycl::elem::s0>() &&
42+
testVec2.template swizzle<sycl::elem::s0>();
43+
acc[4] = resVec[0];
44+
});
45+
}).wait();
46+
47+
bool passed = true;
48+
ResultT expected = results[0];
49+
50+
std::cout << "Testing with T = " << typeName << std::endl;
51+
std::cout << "Expected: " << (int)expected << std::endl;
52+
for (int i = 1; i < N; i++) {
53+
std::cout << "Test " << (i - 1) << ": " << ((int)results[i]) << std::endl;
54+
passed &= expected == results[i];
55+
}
56+
std::cout << std::endl;
57+
return passed;
58+
}
59+
60+
int main() {
61+
bool passed = true;
62+
passed &= testAndOperator<bool, std::int8_t>("bool");
63+
passed &= testAndOperator<std::int8_t, std::int8_t>("std::int8_t");
64+
passed &= testAndOperator<float, std::int32_t>("float");
65+
passed &= testAndOperator<int, std::int32_t>("int");
66+
std::cout << (passed ? "Pass" : "Fail") << std::endl;
67+
return (passed ? EXIT_SUCCESS : EXIT_FAILURE);
68+
}

0 commit comments

Comments
 (0)