Skip to content

Commit d9e40ec

Browse files
authored
[SYCL][ESIMD] Add support for different types for lsc functions (#6952)
1 parent 6b24fdc commit d9e40ec

File tree

3 files changed

+81
-35
lines changed

3 files changed

+81
-35
lines changed

llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,7 @@ static const char *LegalSYCLFunctions[] = {
5151
"^sycl::_V1::ext::oneapi::sub_group::.+",
5252
"^sycl::_V1::ext::oneapi::experimental::spec_constant<.+>::.+",
5353
"^sycl::_V1::ext::oneapi::experimental::this_sub_group",
54-
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+",
55-
"^sycl::_V1::ext::oneapi::experimental::tfloat32::.+"};
54+
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+"};
5655

5756
static const char *LegalSYCLFunctionsInStatelessMode[] = {
5857
"^sycl::_V1::multi_ptr<.+>::get",

sycl/include/sycl/ext/intel/experimental/esimd/common.hpp

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -185,17 +185,20 @@ constexpr lsc_data_size expand_data_size(lsc_data_size DS) {
185185
}
186186

187187
template <typename T> struct lsc_expand_type {
188-
using type = typename std::conditional<sizeof(T) < 4, uint32_t, T>::type;
188+
using type = std::conditional_t<
189+
sizeof(T) <= 4,
190+
std::conditional_t<std::is_signed<T>::value, int32_t, uint32_t>,
191+
std::conditional_t<std::is_signed<T>::value, int64_t, uint64_t>>;
189192
};
190193

191194
template <typename T> struct lsc_bitcast_type {
192-
private:
193-
using _type1 = typename std::conditional<sizeof(T) == 2, uint16_t, T>::type;
194-
using _type2 = typename std::conditional<sizeof(T) == 1, uint8_t, T>::type;
195-
196195
public:
197-
using type =
198-
typename std::conditional<sizeof(_type2) == 1, _type2, _type1>::type;
196+
using type = std::conditional_t<
197+
sizeof(T) == 1, uint8_t,
198+
std::conditional_t<
199+
sizeof(T) == 2, uint16_t,
200+
std::conditional_t<sizeof(T) == 4, uint32_t,
201+
std::conditional_t<sizeof(T) == 8, uint64_t, T>>>>;
199202
};
200203

201204
} // namespace detail

sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp

Lines changed: 70 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -449,7 +449,8 @@ __ESIMD_API std::enable_if_t<!std::is_pointer<AccessorTy>::value,
449449
lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
450450
__ESIMD_NS::simd_mask<N> pred = 1) {
451451
#ifdef __ESIMD_FORCE_STATELESS_MEM
452-
return lsc_gather<T, N, DS, L1H>(acc.get_pointer().get(), offsets, pred);
452+
return lsc_gather<T, NElts, DS, L1H, L3H>(acc.get_pointer().get(), offsets,
453+
pred);
453454
#else
454455
detail::check_lsc_vector_size<NElts>();
455456
detail::check_lsc_data_size<T, DS>();
@@ -478,11 +479,11 @@ lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
478479
/// given address, where S is a byte size of an "element" defined by the \c DS
479480
/// template parameter. The maximum size of accessed block is 512 bytes for PVC
480481
/// and 256 bytes for ACM (DG2).
481-
/// When \? DS equals \? lsc_data_size::u64, the address must be 8-byte aligned,
482+
/// When \c DS equals \c lsc_data_size::u64, the address must be 8-byte aligned,
482483
/// otherwise - 4-bytes aligned. Allowed values for the data size are
483-
/// \? lsc_data_size::u32 and \? lsc_data_size::u64. Allowed NElts values are
484+
/// \c lsc_data_size::u32 and \c lsc_data_size::u64. Allowed NElts values are
484485
/// 1, 2, 3, 4, 8, 16, 32, 64.
485-
/// Note that to access 512 bytes, DS must be \? lsc_data_size::u64 and \c NElts
486+
/// Note that to access 512 bytes, DS must be \c lsc_data_size::u64 and \c NElts
486487
/// must be 64.
487488
///
488489
/// @tparam T is element type.
@@ -518,9 +519,19 @@ lsc_block_load(const T *p, __ESIMD_NS::simd_mask<1> pred = 1) {
518519
constexpr detail::lsc_vector_size _VS =
519520
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
520521
if constexpr (SmallIntFactor == 1) {
521-
return __esimd_lsc_load_stateless<T, L1H, L3H, _AddressScale, _ImmOffset,
522-
_DS, _VS, _Transposed, N>(pred.data(),
523-
addrs.data());
522+
if constexpr (_DS == lsc_data_size::u32) {
523+
__ESIMD_NS::simd<uint32_t, NElts> result =
524+
__esimd_lsc_load_stateless<uint32_t, L1H, L3H, _AddressScale,
525+
_ImmOffset, lsc_data_size::u32, _VS,
526+
_Transposed, N>(pred.data(), addrs.data());
527+
return result.template bit_cast_view<T>();
528+
} else {
529+
__ESIMD_NS::simd<uint64_t, NElts> result =
530+
__esimd_lsc_load_stateless<uint64_t, L1H, L3H, _AddressScale,
531+
_ImmOffset, lsc_data_size::u64, _VS,
532+
_Transposed, N>(pred.data(), addrs.data());
533+
return result.template bit_cast_view<T>();
534+
}
524535
} else {
525536
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> result =
526537
__esimd_lsc_load_stateless<uint32_t, L1H, L3H, _AddressScale,
@@ -582,11 +593,20 @@ lsc_block_load(AccessorTy acc, uint32_t offset,
582593
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
583594

584595
if constexpr (SmallIntFactor == 1) {
585-
return __esimd_lsc_load_bti<T, L1H, L3H, _AddressScale, _ImmOffset, _DS,
586-
_VS, _Transposed, N>(pred.data(),
587-
offsets.data(), si);
596+
if constexpr (_DS == lsc_data_size::u32) {
597+
__ESIMD_NS::simd<uint32_t, NElts> result =
598+
__esimd_lsc_load_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
599+
lsc_data_size::u32, _VS, _Transposed, N>(
600+
pred.data(), offsets.data(), si);
601+
return result.template bit_cast_view<T>();
602+
} else {
603+
__ESIMD_NS::simd<uint64_t, NElts> result =
604+
__esimd_lsc_load_bti<uint64_t, L1H, L3H, _AddressScale, _ImmOffset,
605+
lsc_data_size::u64, _VS, _Transposed, N>(
606+
pred.data(), offsets.data(), si);
607+
return result.template bit_cast_view<T>();
608+
}
588609
} else {
589-
590610
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> result =
591611
__esimd_lsc_load_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
592612
lsc_data_size::u32, _VS, _Transposed, N>(
@@ -904,8 +924,8 @@ lsc_scatter(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
904924
__ESIMD_NS::simd<T, N * NElts> vals,
905925
__ESIMD_NS::simd_mask<N> pred = 1) {
906926
#ifdef __ESIMD_FORCE_STATELESS_MEM
907-
lsc_scatter<T, NElts, DS, L1H>(__ESIMD_DNS::accessorToPointer<T>(acc),
908-
offsets, pred);
927+
lsc_scatter<T, NElts, DS, L1H, L3H>(__ESIMD_DNS::accessorToPointer<T>(acc),
928+
offsets, vals, pred);
909929
#else
910930
detail::check_lsc_vector_size<NElts>();
911931
detail::check_lsc_data_size<T, DS>();
@@ -967,13 +987,23 @@ __ESIMD_API void lsc_block_store(T *p, __ESIMD_NS::simd<T, NElts> vals,
967987
constexpr detail::lsc_vector_size _VS =
968988
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
969989
if constexpr (SmallIntFactor == 1) {
970-
971-
__esimd_lsc_store_stateless<T, L1H, L3H, _AddressScale, _ImmOffset, _DS,
972-
_VS, _Transposed, N>(pred.data(), addrs.data(),
973-
vals.data());
990+
if constexpr (_DS == lsc_data_size::u32) {
991+
__esimd_lsc_store_stateless<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
992+
_DS, _VS, _Transposed, N>(
993+
pred.data(), addrs.data(),
994+
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint32_t, NElts>>(
995+
vals.data()));
996+
} else {
997+
__esimd_lsc_store_stateless<uint64_t, L1H, L3H, _AddressScale, _ImmOffset,
998+
_DS, _VS, _Transposed, N>(
999+
pred.data(), addrs.data(),
1000+
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint64_t, NElts>>(
1001+
vals.data()));
1002+
}
9741003
} else {
975-
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> tmp =
976-
vals.template bit_cast_view<uint32_t>();
1004+
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> tmp = sycl::bit_cast<
1005+
__ESIMD_DNS::vector_type_t<uint32_t, NElts / SmallIntFactor>>(
1006+
vals.data());
9771007

9781008
__esimd_lsc_store_stateless<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
9791009
lsc_data_size::u32, _VS, _Transposed, N>(
@@ -1010,7 +1040,7 @@ lsc_block_store(AccessorTy acc, uint32_t offset,
10101040
__ESIMD_NS::simd<T, NElts> vals,
10111041
__ESIMD_NS::simd_mask<1> pred = 1) {
10121042
#ifdef __ESIMD_FORCE_STATELESS_MEM
1013-
lsc_block_store<T, NElts, DS, L1H>(
1043+
lsc_block_store<T, NElts, DS, L1H, L3H>(
10141044
__ESIMD_DNS::accessorToPointer<T>(acc, offset), vals, pred);
10151045
#else
10161046
detail::check_lsc_data_size<T, DS>();
@@ -1033,15 +1063,29 @@ lsc_block_store(AccessorTy acc, uint32_t offset,
10331063
constexpr detail::lsc_vector_size _VS =
10341064
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
10351065
if constexpr (SmallIntFactor > 1) {
1036-
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> Tmp =
1037-
vals.template bit_cast_view<uint32_t>();
10381066
__esimd_lsc_store_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
10391067
lsc_data_size::u32, _VS, _Transposed, N>(
1040-
pred.data(), offsets.data(), Tmp.data(), si);
1068+
pred.data(), offsets.data(),
1069+
sycl::bit_cast<
1070+
__ESIMD_DNS::vector_type_t<uint32_t, NElts / SmallIntFactor>>(
1071+
vals.data()),
1072+
si);
10411073
} else {
1042-
__esimd_lsc_store_bti<T, L1H, L3H, _AddressScale, _ImmOffset, _DS, _VS,
1043-
_Transposed, N>(pred.data(), offsets.data(),
1044-
vals.data(), si);
1074+
if constexpr (_DS == lsc_data_size::u32) {
1075+
__esimd_lsc_store_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset, _DS,
1076+
_VS, _Transposed, N>(
1077+
pred.data(), offsets.data(),
1078+
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint32_t, NElts>>(
1079+
vals.data()),
1080+
si);
1081+
} else {
1082+
__esimd_lsc_store_bti<uint64_t, L1H, L3H, _AddressScale, _ImmOffset, _DS,
1083+
_VS, _Transposed, N>(
1084+
pred.data(), offsets.data(),
1085+
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint64_t, NElts>>(
1086+
vals.data()),
1087+
si);
1088+
}
10451089
}
10461090
#endif
10471091
}

0 commit comments

Comments
 (0)