Skip to content

Commit 90d898b

Browse files
authored
[ESIMD] Change LSC API to improve template argument type deduction (#6764)
1 parent c3ea03d commit 90d898b

File tree

2 files changed

+24
-33
lines changed

2 files changed

+24
-33
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ template <lsc_vector_size VS> constexpr void check_lsc_vector_size() {
108108
"Unsupported vector size");
109109
}
110110

111-
template <uint8_t VS> constexpr void check_lsc_vector_size() {
111+
template <int VS> constexpr void check_lsc_vector_size() {
112112
static_assert(VS == 1 || VS == 2 || VS == 3 || VS == 4 || VS == 8 ||
113113
VS == 16 || VS == 32 || VS == 64,
114114
"Unsupported vector size");
@@ -144,7 +144,7 @@ template <lsc_vector_size VS> constexpr uint8_t to_int() {
144144
}
145145
}
146146

147-
template <uint8_t VS> constexpr lsc_vector_size to_lsc_vector_size() {
147+
template <int VS> constexpr lsc_vector_size to_lsc_vector_size() {
148148
check_lsc_vector_size<VS>();
149149
switch (VS) {
150150
case 1:

sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ constexpr void check_lsc_atomic() {
322322
/// @param pred is predicates.
323323
/// @return is a vector of type T and size N * NElts
324324
///
325-
template <typename T, uint8_t NElts = 1,
325+
template <typename T, int NElts = 1,
326326
lsc_data_size DS = lsc_data_size::default_size, int N>
327327
__ESIMD_API __ESIMD_NS::simd<T, N * NElts>
328328
lsc_slm_gather(__ESIMD_NS::simd<uint32_t, N> offsets,
@@ -357,8 +357,7 @@ lsc_slm_gather(__ESIMD_NS::simd<uint32_t, N> offsets,
357357
/// @param offset is the zero-based offset for SLM buffer in bytes.
358358
/// @return is a vector of type T and size NElts
359359
///
360-
template <typename T, uint8_t NElts = 1,
361-
lsc_data_size DS = lsc_data_size::default_size>
360+
template <typename T, int NElts, lsc_data_size DS = lsc_data_size::default_size>
362361
__ESIMD_API __ESIMD_NS::simd<T, NElts> lsc_slm_block_load(uint32_t offset) {
363362
detail::check_lsc_vector_size<NElts>();
364363
detail::check_lsc_data_size<T, DS>();
@@ -396,10 +395,9 @@ __ESIMD_API __ESIMD_NS::simd<T, NElts> lsc_slm_block_load(uint32_t offset) {
396395
/// @param pred is predicates.
397396
/// @return is a vector of type T and size N * NElts
398397
///
399-
template <typename T, uint8_t NElts = 1,
400-
lsc_data_size DS = lsc_data_size::default_size,
401-
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
402-
int N>
398+
template <
399+
typename T, int NElts = 1, lsc_data_size DS = lsc_data_size::default_size,
400+
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none, int N>
403401
__ESIMD_API __ESIMD_NS::simd<T, N * NElts>
404402
lsc_gather(const T *p, __ESIMD_NS::simd<uint32_t, N> offsets,
405403
__ESIMD_NS::simd_mask<N> pred = 1) {
@@ -442,7 +440,7 @@ lsc_gather(const T *p, __ESIMD_NS::simd<uint32_t, N> offsets,
442440
/// @param pred is predicates.
443441
/// @return is a vector of type T and size N * NElts
444442
///
445-
template <typename T, uint8_t NElts = 1,
443+
template <typename T, int NElts = 1,
446444
lsc_data_size DS = lsc_data_size::default_size,
447445
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
448446
int N, typename AccessorTy>
@@ -490,8 +488,7 @@ lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
490488
/// the operation.
491489
/// @return is a vector of type T and size NElts
492490
///
493-
template <typename T, uint8_t NElts = 1,
494-
lsc_data_size DS = lsc_data_size::default_size,
491+
template <typename T, int NElts, lsc_data_size DS = lsc_data_size::default_size,
495492
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none>
496493
__ESIMD_API __ESIMD_NS::simd<T, NElts>
497494
lsc_block_load(const T *p, __ESIMD_NS::simd_mask<1> pred = 1) {
@@ -533,8 +530,7 @@ lsc_block_load(const T *p, __ESIMD_NS::simd_mask<1> pred = 1) {
533530
/// the operation.
534531
/// @return is a vector of type T and size NElts
535532
///
536-
template <typename T, uint8_t NElts = 1,
537-
lsc_data_size DS = lsc_data_size::default_size,
533+
template <typename T, int NElts, lsc_data_size DS = lsc_data_size::default_size,
538534
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
539535
typename AccessorTy>
540536
__ESIMD_API std::enable_if_t<!std::is_pointer<AccessorTy>::value,
@@ -580,10 +576,9 @@ lsc_block_load(AccessorTy acc, uint32_t offset,
580576
/// @param offsets is the zero-based offsets in bytes.
581577
/// @param pred is predicates.
582578
///
583-
template <typename T, uint8_t NElts = 1,
584-
lsc_data_size DS = lsc_data_size::default_size,
585-
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
586-
int N>
579+
template <
580+
typename T, int NElts = 1, lsc_data_size DS = lsc_data_size::default_size,
581+
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none, int N>
587582
__ESIMD_API void lsc_prefetch(const T *p, __ESIMD_NS::simd<uint32_t, N> offsets,
588583
__ESIMD_NS::simd_mask<N> pred = 1) {
589584
detail::check_lsc_vector_size<NElts>();
@@ -617,7 +612,7 @@ __ESIMD_API void lsc_prefetch(const T *p, __ESIMD_NS::simd<uint32_t, N> offsets,
617612
/// @tparam L3H is L3 cache hint.
618613
/// @param p is the base pointer.
619614
///
620-
template <typename T, uint8_t NElts = 1,
615+
template <typename T, int NElts = 1,
621616
lsc_data_size DS = lsc_data_size::default_size,
622617
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none>
623618
__ESIMD_API void lsc_prefetch(const T *p) {
@@ -658,7 +653,7 @@ __ESIMD_API void lsc_prefetch(const T *p) {
658653
/// @param offsets is the zero-based offsets in bytes.
659654
/// @param pred is predicates.
660655
///
661-
template <typename T, uint8_t NElts = 1,
656+
template <typename T, int NElts = 1,
662657
lsc_data_size DS = lsc_data_size::default_size,
663658
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
664659
int N, typename AccessorTy>
@@ -701,7 +696,7 @@ lsc_prefetch(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
701696
/// @param acc is the SYCL accessor.
702697
/// @param offset is the zero-based offset in bytes.
703698
///
704-
template <typename T, uint8_t NElts = 1,
699+
template <typename T, int NElts = 1,
705700
lsc_data_size DS = lsc_data_size::default_size,
706701
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
707702
typename AccessorTy>
@@ -746,7 +741,7 @@ lsc_prefetch(AccessorTy acc, uint32_t offset) {
746741
/// @param vals is values to store.
747742
/// @param pred is predicates.
748743
///
749-
template <typename T, uint8_t NElts = 1,
744+
template <typename T, int NElts = 1,
750745
lsc_data_size DS = lsc_data_size::default_size, int N>
751746
__ESIMD_API void lsc_slm_scatter(__ESIMD_NS::simd<uint32_t, N> offsets,
752747
__ESIMD_NS::simd<T, N * NElts> vals,
@@ -780,8 +775,7 @@ __ESIMD_API void lsc_slm_scatter(__ESIMD_NS::simd<uint32_t, N> offsets,
780775
/// @param offset is the zero-based offset for SLM buffer in bytes.
781776
/// @param vals is values to store.
782777
///
783-
template <typename T, uint8_t NElts = 1,
784-
lsc_data_size DS = lsc_data_size::default_size>
778+
template <typename T, int NElts, lsc_data_size DS = lsc_data_size::default_size>
785779
__ESIMD_API void lsc_slm_block_store(uint32_t offset,
786780
__ESIMD_NS::simd<T, NElts> vals) {
787781
detail::check_lsc_vector_size<NElts>();
@@ -819,10 +813,9 @@ __ESIMD_API void lsc_slm_block_store(uint32_t offset,
819813
/// @param vals is values to store.
820814
/// @param pred is predicates.
821815
///
822-
template <typename T, uint8_t NElts = 1,
823-
lsc_data_size DS = lsc_data_size::default_size,
824-
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
825-
int N>
816+
template <
817+
typename T, int NElts = 1, lsc_data_size DS = lsc_data_size::default_size,
818+
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none, int N>
826819
__ESIMD_API void lsc_scatter(T *p, __ESIMD_NS::simd<uint32_t, N> offsets,
827820
__ESIMD_NS::simd<T, N * NElts> vals,
828821
__ESIMD_NS::simd_mask<N> pred = 1) {
@@ -864,7 +857,7 @@ __ESIMD_API void lsc_scatter(T *p, __ESIMD_NS::simd<uint32_t, N> offsets,
864857
/// @param vals is values to store.
865858
/// @param pred is predicates.
866859
///
867-
template <typename T, uint8_t NElts = 1,
860+
template <typename T, int NElts = 1,
868861
lsc_data_size DS = lsc_data_size::default_size,
869862
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
870863
int N, typename AccessorTy>
@@ -913,8 +906,7 @@ lsc_scatter(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
913906
/// entirely, non-zero - operation is performed. The default is '1' - perform
914907
/// the operation.
915908
///
916-
template <typename T, uint8_t NElts = 1,
917-
lsc_data_size DS = lsc_data_size::default_size,
909+
template <typename T, int NElts, lsc_data_size DS = lsc_data_size::default_size,
918910
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none>
919911
__ESIMD_API void lsc_block_store(T *p, __ESIMD_NS::simd<T, NElts> vals,
920912
__ESIMD_NS::simd_mask<1> pred = 1) {
@@ -955,8 +947,7 @@ __ESIMD_API void lsc_block_store(T *p, __ESIMD_NS::simd<T, NElts> vals,
955947
/// entirely, non-zero - operation is performed. The default is '1' - perform
956948
/// the operation.
957949
///
958-
template <typename T, uint8_t NElts = 1,
959-
lsc_data_size DS = lsc_data_size::default_size,
950+
template <typename T, int NElts, lsc_data_size DS = lsc_data_size::default_size,
960951
cache_hint L1H = cache_hint::none, cache_hint L3H = cache_hint::none,
961952
typename AccessorTy>
962953
__ESIMD_API std::enable_if_t<!std::is_pointer<AccessorTy>::value>

0 commit comments

Comments
 (0)