Skip to content

Commit c603a7f

Browse files
[NFCI] Move sycl::vec::convert definition to vector_convert.hpp (#15160)
1) It keeps all `convert`-related code in one place. 2) It makes it possible not to include that extra functionality when this particular method isn't used. 3) Potentially, it moves out of `vector.hpp` all the places where `half`/`bfloat16` need to be "complete" types. That can't really be verified at the moment, because `generic_type_traits.hpp` currently includes the definitions of both.
1 parent de43757 commit c603a7f

File tree

3 files changed

+103
-96
lines changed

3 files changed

+103
-96
lines changed

sycl/include/sycl/detail/vector_convert.hpp

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
#include <sycl/exception.hpp> // for errc
5959

6060
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
61+
#include <sycl/vector.hpp>
6162

6263
#ifndef __SYCL_DEVICE_ONLY__
6364
#include <cfenv> // for fesetround, fegetround
@@ -153,8 +154,6 @@ __imf_ushort_as_bfloat16(unsigned short x);
153154

154155
namespace sycl {
155156

156-
enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };
157-
158157
inline namespace _V1 {
159158
#ifndef __SYCL_DEVICE_ONLY__
160159
// TODO: Refactor includes so we can just "#include".
@@ -870,6 +869,98 @@ auto ConvertImpl(std::byte val) {
870869
}
871870
#endif
872871

872+
// We interpret both bool and std::byte as uint8_t for conversion to other
873+
// types.
874+
template <typename T>
875+
using ConvertBoolAndByteT =
876+
typename detail::map_type<T,
877+
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
878+
std::byte, /*->*/ std::uint8_t, //
879+
#endif
880+
bool, /*->*/ std::uint8_t, //
881+
T, /*->*/ T //
882+
>::type;
873883
} // namespace detail
884+
885+
template <typename DataT, int NumElements>
886+
template <typename convertT, rounding_mode roundingMode>
887+
vec<convertT, NumElements> vec<DataT, NumElements>::convert() const {
888+
using T = detail::ConvertBoolAndByteT<DataT>;
889+
using R = detail::ConvertBoolAndByteT<convertT>;
890+
using bfloat16 = sycl::ext::oneapi::bfloat16;
891+
static_assert(std::is_integral_v<R> || detail::is_floating_point<R>::value ||
892+
std::is_same_v<R, bfloat16>,
893+
"Unsupported convertT");
894+
895+
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
896+
using OpenCLR = detail::ConvertToOpenCLType_t<R>;
897+
vec<convertT, NumElements> Result;
898+
899+
// convertImpl can't be called with the same From and To types and therefore
900+
// we need some special processing in a few cases.
901+
if constexpr (std::is_same_v<DataT, convertT>) {
902+
return *this;
903+
} else if constexpr (std::is_same_v<OpenCLT, OpenCLR> ||
904+
std::is_same_v<T, R>) {
905+
for (size_t I = 0; I < NumElements; ++I)
906+
Result[I] = static_cast<convertT>(getValue(I));
907+
return Result;
908+
} else {
909+
910+
#ifdef __SYCL_DEVICE_ONLY__
911+
using OpenCLVecT = OpenCLT __attribute__((ext_vector_type(NumElements)));
912+
using OpenCLVecR = OpenCLR __attribute__((ext_vector_type(NumElements)));
913+
914+
auto NativeVector = sycl::bit_cast<vector_t>(*this);
915+
using ConvertTVecType = typename vec<convertT, NumElements>::vector_t;
916+
917+
// Whole vector conversion can only be done, if:
918+
constexpr bool canUseNativeVectorConvert =
919+
#ifdef __NVPTX__
920+
// TODO: Likely unnecessary as
921+
// https://github.com/intel/llvm/issues/11840 has been closed
922+
// already.
923+
false &&
924+
#endif
925+
NumElements > 1 &&
926+
// - vec storage has an equivalent OpenCL native vector it is
927+
// implicitly convertible to. There are some corner cases where it
928+
// is not the case with char, long and long long types.
929+
std::is_convertible_v<vector_t, OpenCLVecT> &&
930+
std::is_convertible_v<ConvertTVecType, OpenCLVecR> &&
931+
// - it is not a signed to unsigned (or vice versa) conversion
932+
// see comments within 'convertImpl' for more details;
933+
!detail::is_sint_to_from_uint<T, R>::value &&
934+
// - destination type is not bool. bool is stored as integer under the
935+
// hood and therefore conversion to bool looks like conversion
936+
// between two integer types. Since bit pattern for true and false
937+
// is not defined, there is no guarantee that integer conversion
938+
// yields right results here;
939+
!std::is_same_v<convertT, bool>;
940+
941+
if constexpr (canUseNativeVectorConvert) {
942+
auto val = detail::convertImpl<T, R, roundingMode, NumElements,
943+
OpenCLVecT, OpenCLVecR>(NativeVector);
944+
Result.m_Data = sycl::bit_cast<decltype(Result.m_Data)>(val);
945+
} else
946+
#endif // __SYCL_DEVICE_ONLY__
947+
{
948+
// Otherwise, we fall back to per-element conversion:
949+
for (size_t I = 0; I < NumElements; ++I) {
950+
auto val = detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
951+
getValue(I));
952+
#ifdef __SYCL_DEVICE_ONLY__
953+
// On device, we interpret BF16 as uint16.
954+
if constexpr (std::is_same_v<convertT, bfloat16>)
955+
Result[I] = sycl::bit_cast<convertT>(val);
956+
else
957+
#endif
958+
Result[I] = static_cast<convertT>(val);
959+
}
960+
}
961+
}
962+
return Result;
963+
}
964+
874965
} // namespace _V1
875966
} // namespace sycl

sycl/include/sycl/types.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include <sycl/half_type.hpp> // for StorageT, half, Vec16...
2222
#include <sycl/marray.hpp> // for __SYCL_BINOP, __SYCL_...
2323
#include <sycl/multi_ptr.hpp> // for multi_ptr
24-
#include <sycl/vector.hpp> // for sycl::vec and swizzles
24+
25+
#include <sycl/vector.hpp>
26+
27+
#include <sycl/detail/vector_convert.hpp>
2528

2629
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

sycl/include/sycl/vector.hpp

Lines changed: 6 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
#include <sycl/detail/type_list.hpp> // for is_contained
3737
#include <sycl/detail/type_traits.hpp> // for is_floating_point
3838
#include <sycl/detail/vector_arith.hpp>
39-
#include <sycl/detail/vector_convert.hpp> // for convertImpl
4039
#include <sycl/half_type.hpp> // for StorageT, half, Vec16...
4140

4241
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
@@ -53,6 +52,10 @@
5352
#include <utility> // for index_sequence, make_...
5453

5554
namespace sycl {
55+
56+
// TODO: Fix in the next ABI-breaking window.
57+
enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };
58+
5659
inline namespace _V1 {
5760

5861
struct elem {
@@ -406,18 +409,6 @@ class __SYCL_EBO vec
406409
static constexpr size_t byte_size() noexcept { return sizeof(m_Data); }
407410

408411
private:
409-
// We interpret bool as int8_t, std::byte as uint8_t for conversion to other
410-
// types.
411-
template <typename T>
412-
using ConvertBoolAndByteT =
413-
typename detail::map_type<T,
414-
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
415-
std::byte, /*->*/ std::uint8_t, //
416-
#endif
417-
bool, /*->*/ std::uint8_t, //
418-
T, /*->*/ T //
419-
>::type;
420-
421412
// getValue should be able to operate on different underlying
422413
// types: enum cl_float#N , builtin vector float#N, builtin type float.
423414
constexpr auto getValue(int Index) const {
@@ -439,88 +430,10 @@ class __SYCL_EBO vec
439430
}
440431

441432
public:
433+
// Out-of-class definition is in `sycl/detail/vector_convert.hpp`
442434
template <typename convertT,
443435
rounding_mode roundingMode = rounding_mode::automatic>
444-
vec<convertT, NumElements> convert() const {
445-
446-
using T = ConvertBoolAndByteT<DataT>;
447-
using R = ConvertBoolAndByteT<convertT>;
448-
using bfloat16 = sycl::ext::oneapi::bfloat16;
449-
static_assert(std::is_integral_v<R> ||
450-
detail::is_floating_point<R>::value ||
451-
std::is_same_v<R, bfloat16>,
452-
"Unsupported convertT");
453-
454-
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
455-
using OpenCLR = detail::ConvertToOpenCLType_t<R>;
456-
vec<convertT, NumElements> Result;
457-
458-
// convertImpl can't be called with the same From and To types and therefore
459-
// we need some special processing in a few cases.
460-
if constexpr (std::is_same_v<DataT, convertT>) {
461-
return *this;
462-
} else if constexpr (std::is_same_v<OpenCLT, OpenCLR> ||
463-
std::is_same_v<T, R>) {
464-
for (size_t I = 0; I < NumElements; ++I)
465-
Result[I] = static_cast<convertT>(getValue(I));
466-
return Result;
467-
} else {
468-
469-
#ifdef __SYCL_DEVICE_ONLY__
470-
using OpenCLVecT = OpenCLT __attribute__((ext_vector_type(NumElements)));
471-
using OpenCLVecR = OpenCLR __attribute__((ext_vector_type(NumElements)));
472-
473-
auto NativeVector = sycl::bit_cast<vector_t>(*this);
474-
using ConvertTVecType = typename vec<convertT, NumElements>::vector_t;
475-
476-
// Whole vector conversion can only be done, if:
477-
constexpr bool canUseNativeVectorConvert =
478-
#ifdef __NVPTX__
479-
// TODO: Likely unnecessary as
480-
// https://github.com/intel/llvm/issues/11840 has been closed
481-
// already.
482-
false &&
483-
#endif
484-
NumElements > 1 &&
485-
// - vec storage has an equivalent OpenCL native vector it is
486-
// implicitly convertible to. There are some corner cases where it
487-
// is not the case with char, long and long long types.
488-
std::is_convertible_v<vector_t, OpenCLVecT> &&
489-
std::is_convertible_v<ConvertTVecType, OpenCLVecR> &&
490-
// - it is not a signed to unsigned (or vice versa) conversion
491-
// see comments within 'convertImpl' for more details;
492-
!detail::is_sint_to_from_uint<T, R>::value &&
493-
// - destination type is not bool. bool is stored as integer under the
494-
// hood and therefore conversion to bool looks like conversion
495-
// between two integer types. Since bit pattern for true and false
496-
// is not defined, there is no guarantee that integer conversion
497-
// yields right results here;
498-
!std::is_same_v<convertT, bool>;
499-
500-
if constexpr (canUseNativeVectorConvert) {
501-
auto val = detail::convertImpl<T, R, roundingMode, NumElements, OpenCLVecT,
502-
OpenCLVecR>(NativeVector);
503-
Result.m_Data = sycl::bit_cast<decltype(Result.m_Data)>(val);
504-
} else
505-
#endif // __SYCL_DEVICE_ONLY__
506-
{
507-
// Otherwise, we fallback to per-element conversion:
508-
for (size_t I = 0; I < NumElements; ++I) {
509-
auto val =
510-
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
511-
getValue(I));
512-
#ifdef __SYCL_DEVICE_ONLY__
513-
// On device, we interpret BF16 as uint16.
514-
if constexpr (std::is_same_v<convertT, bfloat16>)
515-
Result[I] = sycl::bit_cast<convertT>(val);
516-
else
517-
#endif
518-
Result[I] = static_cast<convertT>(val);
519-
}
520-
}
521-
}
522-
return Result;
523-
}
436+
vec<convertT, NumElements> convert() const;
524437

525438
template <typename asT> asT as() const { return sycl::bit_cast<asT>(*this); }
526439

0 commit comments

Comments
 (0)