Skip to content

Commit 2c39d67

Browse files
authored
[SYCL][COMPAT] Add vectorized_ternary and vectorized_with_pred. (#15550)
Add vectorized_ternary and vectorized_with_pred. Update relu and vectorized_binary. Signed-off-by: Tang, Jiajun jiajun.tang@intel.com --------- Signed-off-by: Tang, Jiajun jiajun.tang@intel.com
1 parent 4cc64bd commit 2c39d67

File tree

5 files changed

+333
-63
lines changed

5 files changed

+333
-63
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,17 +1906,11 @@ template <typename ValueT, typename ValueU>
19061906
inline typename std::enable_if_t<!std::is_floating_point_v<ValueT>, double>
19071907
pow(const ValueT a, const ValueU b);
19081908
1909-
template <typename ValueT>
1910-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1911-
std::is_same_v<sycl::half, ValueT>,
1912-
ValueT>
1913-
relu(const ValueT a);
1909+
template <typename ValueT> inline ValueT relu(const ValueT a);
19141910
1915-
template <class ValueT>
1916-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1917-
std::is_same_v<sycl::half, ValueT>,
1918-
sycl::vec<ValueT, 2>>
1919-
relu(const sycl::vec<ValueT, 2> a);
1911+
template <class ValueT, int NumElements>
1912+
inline sycl::vec<ValueT, NumElements>
1913+
relu(const sycl::vec<ValueT, NumElements> a);
19201914
19211915
template <class ValueT>
19221916
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
@@ -2019,9 +2013,12 @@ inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,
20192013
20202014
`vectorized_binary` computes the `BinaryOperation` for two operands,
20212015
with each value treated as a vector type. `vectorized_unary` offers the same
2022-
interface for operations with a single operand.
2016+
interface for operations with a single operand. `vectorized_ternary` offers the
2017+
interface for three operands with two `BinaryOperation`.
20232018
The implemented `BinaryOperation`s are `abs_diff`, `add_sat`, `rhadd`, `hadd`,
20242019
`maximum`, `minimum`, and `sub_sat`.
2020+
And the `vectorized_with_pred` offers the `BinaryOperation` for two operands,
2021+
meanwihle provides the pred of high/low halfword operation.
20252022
20262023
```cpp
20272024
namespace syclcompat {
@@ -2036,7 +2033,19 @@ struct abs {
20362033
20372034
template <typename VecT, class BinaryOperation>
20382035
inline unsigned vectorized_binary(unsigned a, unsigned b,
2039-
const BinaryOperation binary_op);
2036+
const BinaryOperation binary_op,
2037+
bool need_relu = false);
2038+
2039+
template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
2040+
inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
2041+
const BinaryOperation1 binary_op1,
2042+
const BinaryOperation2 binary_op2,
2043+
bool need_relu = false);
2044+
2045+
template <typename ValueT, typename BinaryOperation>
2046+
inline unsigned vectorized_with_pred(unsigned a, unsigned b,
2047+
const BinaryOperation binary_op,
2048+
bool *pred_hi, bool *pred_lo);
20402049
20412050
// A sycl::abs_diff wrapper functor.
20422051
struct abs_diff {
@@ -2062,11 +2071,15 @@ struct hadd {
20622071
struct maximum {
20632072
template <typename ValueT>
20642073
auto operator()(const ValueT x, const ValueT y) const;
2074+
template <typename ValueT>
2075+
auto operator()(const ValueT x, const ValueT y, bool *pred) const;
20652076
};
20662077
// A sycl::min wrapper functor.
20672078
struct minimum {
20682079
template <typename ValueT>
20692080
auto operator()(const ValueT x, const ValueT y) const;
2081+
template <typename ValueT>
2082+
auto operator()(const ValueT x, const ValueT y, bool *pred) const;
20702083
};
20712084
// A sycl::sub_sat wrapper functor.
20722085
struct sub_sat {

sycl/include/syclcompat/math.hpp

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -856,23 +856,24 @@ pow(const ValueT a, const ValueU b) {
856856
/// Performs relu saturation.
857857
/// \param [in] a The input value
858858
/// \returns the relu saturation result
859-
template <typename ValueT>
860-
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
861-
relu(const ValueT a) {
862-
if (!detail::isnan(a) && a < ValueT(0))
859+
template <typename ValueT> inline ValueT relu(const ValueT a) {
860+
if constexpr (syclcompat::is_floating_point_v<ValueT>)
861+
if (detail::isnan(a))
862+
return a;
863+
if (a < ValueT(0))
863864
return ValueT(0);
864865
return a;
865866
}
866-
template <class ValueT>
867-
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>,
868-
sycl::vec<ValueT, 2>>
869-
relu(const sycl::vec<ValueT, 2> a) {
870-
return {relu(a[0]), relu(a[1])};
867+
template <class ValueT, int NumElements>
868+
inline sycl::vec<ValueT, NumElements>
869+
relu(const sycl::vec<ValueT, NumElements> a) {
870+
sycl::vec<ValueT, NumElements> ret;
871+
for (int i = 0; i < NumElements; ++i)
872+
ret[i] = relu(a[i]);
873+
return ret;
871874
}
872875
template <class ValueT>
873-
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>,
874-
sycl::marray<ValueT, 2>>
875-
relu(const sycl::marray<ValueT, 2> a) {
876+
inline sycl::marray<ValueT, 2> relu(const sycl::marray<ValueT, 2> a) {
876877
return {relu(a[0]), relu(a[1])};
877878
}
878879

@@ -990,6 +991,10 @@ struct maximum {
990991
auto operator()(const ValueT x, const ValueT y) const {
991992
return sycl::max(x, y);
992993
}
994+
template <typename ValueT>
995+
auto operator()(const ValueT x, const ValueT y, bool *pred) const {
996+
return (x >= y) ? ((*pred = true), x) : ((*pred = false), y);
997+
}
993998
};
994999

9951000
/// A sycl::min wrapper functors.
@@ -998,6 +1003,10 @@ struct minimum {
9981003
auto operator()(const ValueT x, const ValueT y) const {
9991004
return sycl::min(x, y);
10001005
}
1006+
template <typename ValueT>
1007+
auto operator()(const ValueT x, const ValueT y, bool *pred) const {
1008+
return (x <= y) ? ((*pred = true), x) : ((*pred = false), y);
1009+
}
10011010
};
10021011

10031012
/// A sycl::sub_sat wrapper functors.
@@ -1037,19 +1046,76 @@ struct average {
10371046
/// \tparam [in] BinaryOperation The binary operation class
10381047
/// \param [in] a The first value
10391048
/// \param [in] b The second value
1049+
/// \param [in] binary_op The operation to do with the two values
1050+
/// \param [in] need_relu Whether the result need relu saturation
10401051
/// \returns The vectorized binary operation value of the two values
10411052
template <typename VecT, class BinaryOperation>
10421053
inline unsigned vectorized_binary(unsigned a, unsigned b,
1043-
const BinaryOperation binary_op) {
1054+
const BinaryOperation binary_op,
1055+
bool need_relu = false) {
10441056
sycl::vec<unsigned, 1> v0{a}, v1{b};
10451057
auto v2 = v0.as<VecT>();
10461058
auto v3 = v1.as<VecT>();
10471059
auto v4 =
10481060
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1061+
if (need_relu)
1062+
v4 = relu(v4);
10491063
v0 = v4.template as<sycl::vec<unsigned, 1>>();
10501064
return v0;
10511065
}
10521066

1067+
/// Compute two vectorized binary operation value with pred for three values,
1068+
/// with each value treated as a 2 \p T type elements vector type.
1069+
///
1070+
/// \tparam [in] VecT The type of the vector
1071+
/// \tparam [in] BinaryOperation1 The first binary operation class
1072+
/// \tparam [in] BinaryOperation2 The second binary operation class
1073+
/// \param [in] a The first value
1074+
/// \param [in] b The second value
1075+
/// \param [in] c The third value
1076+
/// \param [in] binary_op1 The first operation to do with the first two values
1077+
/// \param [in] binary_op2 The second operation to do with the third values
1078+
/// \param [in] need_relu Whether the result need relu saturation
1079+
/// \returns The two vectorized binary operation value of the three values
1080+
template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
1081+
inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
1082+
const BinaryOperation1 binary_op1,
1083+
const BinaryOperation2 binary_op2,
1084+
bool need_relu = false) {
1085+
const auto v1 = sycl::vec<unsigned, 1>(a).as<VecT>();
1086+
const auto v2 = sycl::vec<unsigned, 1>(b).as<VecT>();
1087+
const auto v3 = sycl::vec<unsigned, 1>(c).as<VecT>();
1088+
auto v4 =
1089+
detail::vectorized_binary<VecT, BinaryOperation1>()(v1, v2, binary_op1);
1090+
v4 = detail::vectorized_binary<VecT, BinaryOperation2>()(v4, v3, binary_op2);
1091+
if (need_relu)
1092+
v4 = relu(v4);
1093+
return v4.template as<sycl::vec<unsigned, 1>>();
1094+
}
1095+
1096+
/// Compute vectorized binary operation value with pred for two values, with
1097+
/// each value treated as a 2 \p T type elements vector type.
1098+
///
1099+
/// \tparam [in] VecT The type of the vector
1100+
/// \tparam [in] BinaryOperation The binary operation class
1101+
/// \param [in] a The first value
1102+
/// \param [in] b The second value
1103+
/// \param [in] binary_op The operation with pred to do with the two values
1104+
/// \param [out] pred_hi The pred pointer that pass into high halfword operation
1105+
/// \param [out] pred_lo The pred pointer that pass into low halfword operation
1106+
/// \returns The vectorized binary operation value of the two values
1107+
template <typename VecT, typename BinaryOperation>
1108+
inline unsigned vectorized_binary_with_pred(unsigned a, unsigned b,
1109+
const BinaryOperation binary_op,
1110+
bool *pred_hi, bool *pred_lo) {
1111+
auto v1 = sycl::vec<unsigned, 1>(a).as<VecT>();
1112+
auto v2 = sycl::vec<unsigned, 1>(b).as<VecT>();
1113+
VecT ret;
1114+
ret[0] = binary_op(v1[0], v2[0], pred_lo);
1115+
ret[1] = binary_op(v1[1], v2[1], pred_hi);
1116+
return ret.template as<sycl::vec<unsigned, 1>>();
1117+
}
1118+
10531119
template <typename T1, typename T2>
10541120
using dot_product_acc_t =
10551121
std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,

sycl/test-e2e/syclcompat/math/math_fixt.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
134134
ValueT *op1_;
135135
ValueU *op2_;
136136
ResultT res_h_, *res_;
137+
bool *res_hi_;
138+
bool *res_lo_;
137139

138140
public:
139141
BinaryOpTestLauncher(const syclcompat::dim3 &grid,
@@ -147,6 +149,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
147149
op1_ = syclcompat::malloc<ValueT>(data_size);
148150
op2_ = syclcompat::malloc<ValueU>(data_size);
149151
res_ = syclcompat::malloc<ResultT>(data_size);
152+
res_hi_ = syclcompat::malloc<bool>(1);
153+
res_lo_ = syclcompat::malloc<bool>(1);
150154
};
151155

152156
virtual ~BinaryOpTestLauncher() {
@@ -155,6 +159,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
155159
syclcompat::free(op1_);
156160
syclcompat::free(op2_);
157161
syclcompat::free(res_);
162+
syclcompat::free(res_hi_);
163+
syclcompat::free(res_lo_);
158164
}
159165

160166
template <auto Kernel>
@@ -169,6 +175,37 @@ class BinaryOpTestLauncher : OpTestLauncher {
169175

170176
CHECK(ResultT, res_h_, expected);
171177
};
178+
template <auto Kernel>
179+
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool need_relu) {
180+
if (skip_)
181+
return;
182+
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
183+
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
184+
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, need_relu);
185+
syclcompat::wait();
186+
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);
187+
188+
CHECK(ResultT, res_h_, expected);
189+
};
190+
template <auto Kernel>
191+
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool expected_hi,
192+
bool expected_lo) {
193+
if (skip_)
194+
return;
195+
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
196+
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
197+
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, res_hi_,
198+
res_lo_);
199+
syclcompat::wait();
200+
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);
201+
bool res_hi_h_, res_lo_h_;
202+
syclcompat::memcpy<bool>(&res_hi_h_, res_hi_, 1);
203+
syclcompat::memcpy<bool>(&res_lo_h_, res_lo_, 1);
204+
205+
CHECK(ResultT, res_h_, expected);
206+
assert(res_hi_h_ == expected_hi);
207+
assert(res_lo_h_ == expected_lo);
208+
};
172209
};
173210

174211
template <typename ValueT, typename ResultT = ValueT>
@@ -208,3 +245,54 @@ class UnaryOpTestLauncher : OpTestLauncher {
208245
CHECK(ResultT, res_h_, expected);
209246
}
210247
};
248+
249+
// Templated ResultT to support both arithmetic and boolean operators
250+
template <typename ValueT, typename ValueU, typename ValueV,
251+
typename ResultT = std::common_type_t<ValueT, ValueU, ValueV>>
252+
class TernaryOpTestLauncher : OpTestLauncher {
253+
protected:
254+
ValueT *op1_;
255+
ValueU *op2_;
256+
ValueV *op3_;
257+
ResultT res_h_, *res_;
258+
259+
public:
260+
TernaryOpTestLauncher(const syclcompat::dim3 &grid,
261+
const syclcompat::dim3 &threads,
262+
const size_t data_size = 1)
263+
: OpTestLauncher{grid, threads, data_size,
264+
should_skip<ValueT, ValueU, ValueV, ResultT>()(
265+
syclcompat::get_current_device())} {
266+
if (skip_)
267+
return;
268+
op1_ = syclcompat::malloc<ValueT>(data_size);
269+
op2_ = syclcompat::malloc<ValueU>(data_size);
270+
op3_ = syclcompat::malloc<ValueV>(data_size);
271+
res_ = syclcompat::malloc<ResultT>(data_size);
272+
};
273+
274+
virtual ~TernaryOpTestLauncher() {
275+
if (skip_)
276+
return;
277+
syclcompat::free(op1_);
278+
syclcompat::free(op2_);
279+
syclcompat::free(op3_);
280+
syclcompat::free(res_);
281+
}
282+
283+
template <auto Kernel>
284+
void launch_test(ValueT op1, ValueU op2, ValueV op3, ResultT expected,
285+
bool need_relu = false) {
286+
if (skip_)
287+
return;
288+
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
289+
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
290+
syclcompat::memcpy<ValueV>(op3_, &op3, data_size_);
291+
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_,
292+
need_relu);
293+
syclcompat::wait();
294+
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);
295+
296+
CHECK(ResultT, res_h_, expected);
297+
};
298+
};

sycl/test-e2e/syclcompat/math/math_ops.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,10 @@ template <typename ValueT> void test_syclcompat_relu() {
226226
UnaryOpTestLauncher<ValueT>(grid, threads)
227227
.template launch_test<relu_kernel<ValueT>>(op1, res1);
228228

229-
const ValueT op2 = static_cast<ValueT>(-3);
230-
const ValueT res2 = static_cast<ValueT>(0);
229+
const ValueT op2 = std::is_signed_v<ValueT> ? static_cast<ValueT>(-3)
230+
: static_cast<ValueT>(2);
231+
const ValueT res2 = std::is_signed_v<ValueT> ? static_cast<ValueT>(0)
232+
: static_cast<ValueT>(2);
231233
UnaryOpTestLauncher<ValueT>(grid, threads)
232234
.template launch_test<relu_kernel<ValueT>>(op2, res2);
233235

@@ -374,7 +376,7 @@ int main() {
374376
test_syclcompat_pow<float, int>();
375377
test_syclcompat_pow<double, int>();
376378

377-
INSTANTIATE_ALL_TYPES(fp_type_list, test_syclcompat_relu);
379+
INSTANTIATE_ALL_TYPES(value_type_list, test_syclcompat_relu);
378380
INSTANTIATE_ALL_TYPES(fp_type_list_no_bfloat16, test_syclcompat_cbrt);
379381

380382
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_isnan);

0 commit comments

Comments
 (0)