Skip to content

Commit 0d3eb05

Browse files
[NFC][SYCL] Drop scalar half/bfloat16 type traits (#15706)
`std::is_same_v` is much clearer when checking for a single specific type, and the helper in `sycl/vector.hpp` could include more than these two types to increase its readability too.
1 parent 3bd93c4 commit 0d3eb05

File tree

4 files changed

+20
-37
lines changed

4 files changed

+20
-37
lines changed

sycl/include/sycl/detail/generic_type_lists.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ using scalar_vector_bfloat16_list =
7070
using bfloat16_list =
7171
tl_append<scalar_bfloat16_list, vector_bfloat16_list, marray_bfloat16_list>;
7272

73-
using half_bfloat16_list = tl_append<scalar_half_list, scalar_bfloat16_list>;
74-
7573
using scalar_float_list = type_list<float>;
7674

7775
using vector_float_list =

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,6 @@ template <typename T>
3131
inline constexpr bool is_svgenfloatf_v =
3232
is_contained_v<T, gtl::scalar_vector_float_list>;
3333

34-
template <typename T>
35-
inline constexpr bool is_half_v = is_contained_v<T, gtl::scalar_half_list>;
36-
37-
template <typename T>
38-
inline constexpr bool is_bfloat16_v =
39-
is_contained_v<T, gtl::scalar_bfloat16_list>;
40-
41-
template <typename T>
42-
inline constexpr bool is_half_or_bf16_v =
43-
is_contained_v<T, gtl::half_bfloat16_list>;
44-
4534
template <typename T>
4635
inline constexpr bool is_svgenfloath_v =
4736
is_contained_v<T, gtl::scalar_vector_half_list>;
@@ -141,10 +130,11 @@ template <typename T> auto convertToOpenCLType(T &&x) {
141130
// sycl::half may convert to _Float16, and we would try to instantiate
142131
// vec class with _Float16 DataType, which is not expected there. As
143132
// such, leave vector<half, N> as-is.
144-
using MatchingVec = vec<std::conditional_t<is_half_v<ElemTy>, ElemTy,
145-
decltype(convertToOpenCLType(
146-
std::declval<ElemTy>()))>,
147-
no_ref::size()>;
133+
using MatchingVec =
134+
vec<std::conditional_t<std::is_same_v<ElemTy, half>, ElemTy,
135+
decltype(convertToOpenCLType(
136+
std::declval<ElemTy>()))>,
137+
no_ref::size()>;
148138
#ifdef __SYCL_DEVICE_ONLY__
149139
return sycl::bit_cast<typename MatchingVec::vector_t>(x);
150140
#else
@@ -160,11 +150,11 @@ template <typename T> auto convertToOpenCLType(T &&x) {
160150
fixed_width_unsigned<sizeof(no_ref)>>;
161151
static_assert(sizeof(OpenCLType) == sizeof(T));
162152
return static_cast<OpenCLType>(x);
163-
} else if constexpr (is_half_v<no_ref>) {
153+
} else if constexpr (std::is_same_v<no_ref, half>) {
164154
using OpenCLType = sycl::detail::half_impl::BIsRepresentationT;
165155
static_assert(sizeof(OpenCLType) == sizeof(T));
166156
return static_cast<OpenCLType>(x);
167-
} else if constexpr (is_bfloat16_v<no_ref>) {
157+
} else if constexpr (std::is_same_v<no_ref, ext::oneapi::bfloat16>) {
168158
// On host, don't interpret BF16 as uint16.
169159
#ifdef __SYCL_DEVICE_ONLY__
170160
using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT;

sycl/include/sycl/vector.hpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ struct ScalarConversionOperatorMixIn<Vec, T, N, std::enable_if_t<N == 1>> {
119119
operator T() const { return (*static_cast<const Vec *>(this))[0]; }
120120
};
121121

122+
template <typename T>
123+
inline constexpr bool is_fundamental_or_half_or_bfloat16 =
124+
std::is_fundamental_v<T> || std::is_same_v<std::remove_const_t<T>, half> ||
125+
std::is_same_v<std::remove_const_t<T>, ext::oneapi::bfloat16>;
126+
122127
} // namespace detail
123128

124129
///////////////////////// class sycl::vec /////////////////////////
@@ -288,10 +293,8 @@ class __SYCL_EBO vec
288293
// when NumElements == 1. The template prevents implicit conversion from
289294
// vec<_, 1> to DataT.
290295
template <typename Ty = DataT>
291-
typename std::enable_if_t<
292-
std::is_fundamental_v<Ty> ||
293-
detail::is_half_or_bf16_v<typename std::remove_const_t<Ty>>,
294-
vec &>
296+
typename std::enable_if_t<detail::is_fundamental_or_half_or_bfloat16<Ty>,
297+
vec &>
295298
operator=(const DataT &Rhs) {
296299
*this = vec{Rhs};
297300
return *this;
@@ -626,16 +629,14 @@ class SwizzleOp {
626629
1 != IdxNum && SwizzleOp::getNumElements() == IdxNum, T>;
627630

628631
template <typename T>
629-
using EnableIfScalarType = typename std::enable_if_t<
630-
std::is_convertible_v<DataT, T> &&
631-
(std::is_fundamental_v<T> ||
632-
detail::is_half_or_bf16_v<typename std::remove_const_t<T>>)>;
632+
using EnableIfScalarType =
633+
typename std::enable_if_t<std::is_convertible_v<DataT, T> &&
634+
detail::is_fundamental_or_half_or_bfloat16<T>>;
633635

634636
template <typename T>
635-
using EnableIfNoScalarType = typename std::enable_if_t<
636-
!std::is_convertible_v<DataT, T> ||
637-
!(std::is_fundamental_v<T> ||
638-
detail::is_half_or_bf16_v<typename std::remove_const_t<T>>)>;
637+
using EnableIfNoScalarType =
638+
typename std::enable_if_t<!std::is_convertible_v<DataT, T> ||
639+
!detail::is_fundamental_or_half_or_bfloat16<T>>;
639640

640641
template <int... Indices>
641642
using Swizzle =

sycl/test/basic_tests/generic_type_traits.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,6 @@ int main() {
2323
static_assert(d::is_genfloat_v<s::opencl::cl_float> == true);
2424
static_assert(d::is_genfloat_v<s::vec<s::opencl::cl_float, 4>> == true);
2525

26-
static_assert(d::is_half_v<s::half>);
27-
28-
static_assert(d::is_bfloat16_v<sycl::ext::oneapi::bfloat16>);
29-
static_assert(d::is_half_or_bf16_v<s::half>);
30-
static_assert(d::is_half_or_bf16_v<sycl::ext::oneapi::bfloat16>);
31-
3226
// TODO add checks for the following type traits
3327
/*
3428
is_doublen

0 commit comments

Comments
 (0)