Skip to content

Commit 134618f

Browse files
authored
[SYCL][CUDA] Allow joint_matrix to be loaded from const T (#6532)
Fixes a bug where if `joint_matrix_load` attempts to load `joint_matrix` from an array of `const T`incorrect behaviour will occur or an error will be thrown. To fix this we make use of `std::remove_const_t<T>` in appropriate places. This is important functionality for integrating joint_matrix with existing SYCL-DNN routines. I think that similar problems might occur in the intel backends for their existing impl: I have not made corresponding changes because I do not have the hardware to test it. Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
1 parent 813ca36 commit 134618f

File tree

1 file changed

+55
-53
lines changed

1 file changed

+55
-53
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,11 @@ struct joint_matrix_load_impl<
219219
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
220220
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
221221
multi_ptr<T, Space> src, size_t stride) {
222-
if constexpr (std::is_same<T, uint16_t>::value ||
222+
if constexpr (std::is_same<std::remove_const_t<T>, uint16_t>::value ||
223223
std::is_same<
224-
T, sycl::ext::oneapi::experimental::bfloat16>::value) {
225-
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
224+
std::remove_const_t<T>,
225+
sycl::ext::oneapi::experimental::bfloat16>::value) {
226+
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
226227
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
227228
if constexpr (NumRows == 16 && NumCols == 16) {
228229
if constexpr (Use ==
@@ -247,8 +248,8 @@ struct joint_matrix_load_impl<
247248
__mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
248249
get_layout_id<Layout>());
249250
}
250-
} else if constexpr (std::is_same<T, uint8_t>::value) {
251-
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
251+
} else if constexpr (std::is_same<std::remove_const_t<T>, uint8_t>::value) {
252+
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
252253
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
253254
if constexpr (NumRows == 16 && NumCols == 16) {
254255
if constexpr (Use ==
@@ -273,8 +274,8 @@ struct joint_matrix_load_impl<
273274
__imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
274275
get_layout_id<Layout>());
275276
}
276-
} else if constexpr (std::is_same<T, int8_t>::value) {
277-
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
277+
} else if constexpr (std::is_same<std::remove_const_t<T>, int8_t>::value) {
278+
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
278279
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
279280
if constexpr (NumRows == 16 && NumCols == 16) {
280281
if constexpr (Use ==
@@ -299,8 +300,8 @@ struct joint_matrix_load_impl<
299300
__imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
300301
get_layout_id<Layout>());
301302
}
302-
} else if constexpr (std::is_same<T, half>::value) {
303-
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
303+
} else if constexpr (std::is_same<std::remove_const_t<T>, half>::value) {
304+
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
304305
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
305306
if constexpr (NumRows == 16 && NumCols == 16) {
306307
if constexpr (Use ==
@@ -332,7 +333,7 @@ struct joint_matrix_load_impl<
332333
get_layout_id<Layout>());
333334
}
334335

335-
} else if constexpr (std::is_same<T, int32_t>::value) {
336+
} else if constexpr (std::is_same<std::remove_const_t<T>, int32_t>::value) {
336337
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
337338
if constexpr (NumRows == 16 && NumCols == 16) {
338339
__imma_m16n16k16_ld_c(destptr, src.get(), stride,
@@ -344,7 +345,7 @@ struct joint_matrix_load_impl<
344345
__imma_m32n8k16_ld_c(destptr, src.get(), stride,
345346
get_layout_id<Layout>());
346347
}
347-
} else if constexpr (std::is_same<T, float>::value) {
348+
} else if constexpr (std::is_same<std::remove_const_t<T>, float>::value) {
348349
if constexpr (std::is_same<S, float>::value) {
349350
auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
350351
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -360,7 +361,7 @@ struct joint_matrix_load_impl<
360361
} else if constexpr (std::is_same<S,
361362
sycl::ext::oneapi::experimental::
362363
matrix::precision::tf32>::value) {
363-
auto tileptr = reinterpret_cast<int32_t *>(src.get());
364+
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
364365
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
365366
if constexpr (NumRows == 16 && NumCols == 8) {
366367
__mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride,
@@ -370,7 +371,7 @@ struct joint_matrix_load_impl<
370371
get_layout_id<Layout>());
371372
}
372373
}
373-
} else if constexpr (std::is_same<T, double>::value) {
374+
} else if constexpr (std::is_same<std::remove_const_t<T>, double>::value) {
374375
auto dstptr = reinterpret_cast<double *>(&res.wi_marray);
375376
if constexpr (Use ==
376377
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
@@ -560,9 +561,9 @@ struct joint_matrix_mad_impl<
560561
D;
561562
if constexpr (M == 16 && N == 16 && K == 16) {
562563
if constexpr (std::is_same<T2, int32_t>::value) {
563-
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
564-
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
565-
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
564+
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
565+
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
566+
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
566567
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
567568
if constexpr (std::is_same<T1, int8_t>::value) {
568569
__imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
@@ -572,34 +573,34 @@ struct joint_matrix_mad_impl<
572573
get_layout_pair_id<LayoutA, LayoutB>(), 0);
573574
}
574575
} else if constexpr (std::is_same<T1, half>::value) {
575-
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
576-
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
576+
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
577+
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
577578
if constexpr (std::is_same<T2, float>::value) {
578579
__hmma_m16n16k16_mma_f32f32(
579580
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
580-
reinterpret_cast<float const *>(&C.wi_marray),
581+
reinterpret_cast<const float *>(&C.wi_marray),
581582
get_layout_pair_id<LayoutA, LayoutB>(), 0);
582583
} else if constexpr (std::is_same<T2, half>::value) {
583584
__hmma_m16n16k16_mma_f16f16(
584585
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
585-
reinterpret_cast<int32_t const *>(&C.wi_marray),
586+
reinterpret_cast<const int32_t *>(&C.wi_marray),
586587
get_layout_pair_id<LayoutA, LayoutB>(), 0);
587588
}
588589
} else if constexpr (std::is_same<T1, uint16_t>::value ||
589590
std::is_same<T1, sycl::ext::oneapi::experimental::
590591
bfloat16>::value) {
591592
__mma_bf16_m16n16k16_mma_f32(
592593
reinterpret_cast<float *>(&D.wi_marray),
593-
reinterpret_cast<int32_t const *>(&A.wi_marray),
594-
reinterpret_cast<int32_t const *>(&B.wi_marray),
595-
reinterpret_cast<float const *>(&C.wi_marray),
594+
reinterpret_cast<const int32_t *>(&A.wi_marray),
595+
reinterpret_cast<const int32_t *>(&B.wi_marray),
596+
reinterpret_cast<const float *>(&C.wi_marray),
596597
get_layout_pair_id<LayoutA, LayoutB>(), 0);
597598
}
598599
} else if constexpr (M == 8 && N == 32 && K == 16) {
599600
if constexpr (std::is_same<T2, int32_t>::value) {
600-
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
601-
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
602-
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
601+
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
602+
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
603+
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
603604
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
604605
if constexpr (std::is_same<T1, int8_t>::value) {
605606
__imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
@@ -609,34 +610,34 @@ struct joint_matrix_mad_impl<
609610
get_layout_pair_id<LayoutA, LayoutB>(), 0);
610611
}
611612
} else if constexpr (std::is_same<T1, half>::value) {
612-
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
613-
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
613+
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
614+
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
614615
if constexpr (std::is_same<T2, float>::value) {
615616
__hmma_m8n32k16_mma_f32f32(
616617
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
617-
reinterpret_cast<float const *>(&C.wi_marray),
618+
reinterpret_cast<const float *>(&C.wi_marray),
618619
get_layout_pair_id<LayoutA, LayoutB>(), 0);
619620
} else if constexpr (std::is_same<T2, half>::value) {
620621
__hmma_m8n32k16_mma_f16f16(
621622
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
622-
reinterpret_cast<int32_t const *>(&C.wi_marray),
623+
reinterpret_cast<const int32_t *>(&C.wi_marray),
623624
get_layout_pair_id<LayoutA, LayoutB>(), 0);
624625
}
625626
} else if constexpr (std::is_same<T1, uint16_t>::value ||
626627
std::is_same<T1, sycl::ext::oneapi::experimental::
627628
bfloat16>::value) {
628629
__mma_bf16_m8n32k16_mma_f32(
629630
reinterpret_cast<float *>(&D.wi_marray),
630-
reinterpret_cast<int32_t const *>(&A.wi_marray),
631-
reinterpret_cast<int32_t const *>(&B.wi_marray),
632-
reinterpret_cast<float const *>(&C.wi_marray),
631+
reinterpret_cast<const int32_t *>(&A.wi_marray),
632+
reinterpret_cast<const int32_t *>(&B.wi_marray),
633+
reinterpret_cast<const float *>(&C.wi_marray),
633634
get_layout_pair_id<LayoutA, LayoutB>(), 0);
634635
}
635636
} else if constexpr (M == 32 && N == 8 && K == 16) {
636637
if constexpr (std::is_same<T2, int32_t>::value) {
637-
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
638-
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
639-
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
638+
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
639+
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
640+
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
640641
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
641642
if constexpr (std::is_same<T1, int8_t>::value) {
642643
__imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
@@ -650,22 +651,22 @@ struct joint_matrix_mad_impl<
650651
bfloat16>::value) {
651652
__mma_bf16_m32n8k16_mma_f32(
652653
reinterpret_cast<float *>(&D.wi_marray),
653-
reinterpret_cast<int32_t const *>(&A.wi_marray),
654-
reinterpret_cast<int32_t const *>(&B.wi_marray),
655-
reinterpret_cast<float const *>(&C.wi_marray),
654+
reinterpret_cast<const int32_t *>(&A.wi_marray),
655+
reinterpret_cast<const int32_t *>(&B.wi_marray),
656+
reinterpret_cast<const float *>(&C.wi_marray),
656657
get_layout_pair_id<LayoutA, LayoutB>(), 0);
657658
} else if constexpr (std::is_same<T1, half>::value) {
658-
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
659-
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
659+
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
660+
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
660661
if constexpr (std::is_same<T2, float>::value) {
661662
__hmma_m32n8k16_mma_f32f32(
662663
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
663-
reinterpret_cast<float const *>(&C.wi_marray),
664+
reinterpret_cast<const float *>(&C.wi_marray),
664665
get_layout_pair_id<LayoutA, LayoutB>(), 0);
665666
} else if constexpr (std::is_same<T2, half>::value) {
666667
__hmma_m32n8k16_mma_f16f16(
667668
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
668-
reinterpret_cast<int32_t const *>(&C.wi_marray),
669+
reinterpret_cast<const int32_t *>(&C.wi_marray),
669670
get_layout_pair_id<LayoutA, LayoutB>(), 0);
670671
}
671672
}
@@ -677,9 +678,9 @@ struct joint_matrix_mad_impl<
677678
get_layout_pair_id<LayoutA, LayoutB>(), 0);
678679
} else if constexpr (std::is_same<T1, double>::value) {
679680
__dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),
680-
reinterpret_cast<double const *>(&A.wi_marray),
681-
reinterpret_cast<double const *>(&B.wi_marray),
682-
reinterpret_cast<double const *>(&C.wi_marray),
681+
reinterpret_cast<const double *>(&A.wi_marray),
682+
reinterpret_cast<const double *>(&B.wi_marray),
683+
reinterpret_cast<const double *>(&C.wi_marray),
683684
get_layout_pair_id<LayoutA, LayoutB>(), 0);
684685
}
685686
return D;
@@ -692,13 +693,14 @@ struct joint_matrix_mad_impl<
692693
namespace experimental {
693694
namespace matrix {
694695

695-
template <typename Group, typename S, typename T, matrix_use Use,
696-
size_t NumRows, size_t NumCols, matrix_layout Layout,
697-
access::address_space Space,
698-
std::enable_if_t<std::is_same<S, T>::value ||
699-
(std::is_same<S, precision::tf32>::value &&
700-
std::is_same<T, float>::value),
701-
bool> = true>
696+
template <
697+
typename Group, typename S, typename T, matrix_use Use, size_t NumRows,
698+
size_t NumCols, matrix_layout Layout, access::address_space Space,
699+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
700+
(std::is_same<S, precision::tf32>::value &&
701+
702+
std::is_same<std::remove_const_t<T>, float>::value),
703+
bool> = true>
702704
void joint_matrix_load(
703705
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
704706
multi_ptr<T, Space> src, size_t stride) {

0 commit comments

Comments
 (0)