@@ -219,10 +219,11 @@ struct joint_matrix_load_impl<
219
219
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
220
220
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
221
221
multi_ptr<T, Space> src, size_t stride) {
222
- if constexpr (std::is_same<T , uint16_t >::value ||
222
+ if constexpr (std::is_same<std:: remove_const_t <T> , uint16_t >::value ||
223
223
std::is_same<
224
- T, sycl::ext::oneapi::experimental::bfloat16>::value) {
225
- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
224
+ std::remove_const_t <T>,
225
+ sycl::ext::oneapi::experimental::bfloat16>::value) {
226
+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
226
227
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
227
228
if constexpr (NumRows == 16 && NumCols == 16 ) {
228
229
if constexpr (Use ==
@@ -247,8 +248,8 @@ struct joint_matrix_load_impl<
247
248
__mma_bf16_m32n8k16_ld_b (destptr, tileptr, stride,
248
249
get_layout_id<Layout>());
249
250
}
250
- } else if constexpr (std::is_same<T , uint8_t >::value) {
251
- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
251
+ } else if constexpr (std::is_same<std:: remove_const_t <T> , uint8_t >::value) {
252
+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
252
253
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
253
254
if constexpr (NumRows == 16 && NumCols == 16 ) {
254
255
if constexpr (Use ==
@@ -273,8 +274,8 @@ struct joint_matrix_load_impl<
273
274
__imma_m32n8k16_ld_b_u8 (destptr, tileptr, stride,
274
275
get_layout_id<Layout>());
275
276
}
276
- } else if constexpr (std::is_same<T , int8_t >::value) {
277
- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
277
+ } else if constexpr (std::is_same<std:: remove_const_t <T> , int8_t >::value) {
278
+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
278
279
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
279
280
if constexpr (NumRows == 16 && NumCols == 16 ) {
280
281
if constexpr (Use ==
@@ -299,8 +300,8 @@ struct joint_matrix_load_impl<
299
300
__imma_m32n8k16_ld_b_s8 (destptr, tileptr, stride,
300
301
get_layout_id<Layout>());
301
302
}
302
- } else if constexpr (std::is_same<T , half>::value) {
303
- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
303
+ } else if constexpr (std::is_same<std:: remove_const_t <T> , half>::value) {
304
+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
304
305
auto dstptr = reinterpret_cast <int32_t *>(&res.wi_marray );
305
306
if constexpr (NumRows == 16 && NumCols == 16 ) {
306
307
if constexpr (Use ==
@@ -332,7 +333,7 @@ struct joint_matrix_load_impl<
332
333
get_layout_id<Layout>());
333
334
}
334
335
335
- } else if constexpr (std::is_same<T , int32_t >::value) {
336
+ } else if constexpr (std::is_same<std:: remove_const_t <T> , int32_t >::value) {
336
337
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
337
338
if constexpr (NumRows == 16 && NumCols == 16 ) {
338
339
__imma_m16n16k16_ld_c (destptr, src.get (), stride,
@@ -344,7 +345,7 @@ struct joint_matrix_load_impl<
344
345
__imma_m32n8k16_ld_c (destptr, src.get (), stride,
345
346
get_layout_id<Layout>());
346
347
}
347
- } else if constexpr (std::is_same<T , float >::value) {
348
+ } else if constexpr (std::is_same<std:: remove_const_t <T> , float >::value) {
348
349
if constexpr (std::is_same<S, float >::value) {
349
350
auto dstptr = reinterpret_cast <float *>(&res.wi_marray );
350
351
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -360,7 +361,7 @@ struct joint_matrix_load_impl<
360
361
} else if constexpr (std::is_same<S,
361
362
sycl::ext::oneapi::experimental::
362
363
matrix::precision::tf32>::value) {
363
- auto tileptr = reinterpret_cast <int32_t *>(src.get ());
364
+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
364
365
auto dstptr = reinterpret_cast <int32_t *>(&res.wi_marray );
365
366
if constexpr (NumRows == 16 && NumCols == 8 ) {
366
367
__mma_tf32_m16n16k8_ld_a (dstptr, tileptr, stride,
@@ -370,7 +371,7 @@ struct joint_matrix_load_impl<
370
371
get_layout_id<Layout>());
371
372
}
372
373
}
373
- } else if constexpr (std::is_same<T , double >::value) {
374
+ } else if constexpr (std::is_same<std:: remove_const_t <T> , double >::value) {
374
375
auto dstptr = reinterpret_cast <double *>(&res.wi_marray );
375
376
if constexpr (Use ==
376
377
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
@@ -560,9 +561,9 @@ struct joint_matrix_mad_impl<
560
561
D;
561
562
if constexpr (M == 16 && N == 16 && K == 16 ) {
562
563
if constexpr (std::is_same<T2, int32_t >::value) {
563
- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
564
- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
565
- auto ptrC = reinterpret_cast <int32_t const *>(&C.wi_marray );
564
+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
565
+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
566
+ auto ptrC = reinterpret_cast <const int32_t *>(&C.wi_marray );
566
567
auto ptrD = reinterpret_cast <int32_t *>(&D.wi_marray );
567
568
if constexpr (std::is_same<T1, int8_t >::value) {
568
569
__imma_m16n16k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
@@ -572,34 +573,34 @@ struct joint_matrix_mad_impl<
572
573
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
573
574
}
574
575
} else if constexpr (std::is_same<T1, half>::value) {
575
- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
576
- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
576
+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
577
+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
577
578
if constexpr (std::is_same<T2, float >::value) {
578
579
__hmma_m16n16k16_mma_f32f32 (
579
580
reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
580
- reinterpret_cast <float const *>(&C.wi_marray ),
581
+ reinterpret_cast <const float *>(&C.wi_marray ),
581
582
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
582
583
} else if constexpr (std::is_same<T2, half>::value) {
583
584
__hmma_m16n16k16_mma_f16f16 (
584
585
reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
585
- reinterpret_cast <int32_t const *>(&C.wi_marray ),
586
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
586
587
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
587
588
}
588
589
} else if constexpr (std::is_same<T1, uint16_t >::value ||
589
590
std::is_same<T1, sycl::ext::oneapi::experimental::
590
591
bfloat16>::value) {
591
592
__mma_bf16_m16n16k16_mma_f32 (
592
593
reinterpret_cast <float *>(&D.wi_marray ),
593
- reinterpret_cast <int32_t const *>(&A.wi_marray ),
594
- reinterpret_cast <int32_t const *>(&B.wi_marray ),
595
- reinterpret_cast <float const *>(&C.wi_marray ),
594
+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
595
+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
596
+ reinterpret_cast <const float *>(&C.wi_marray ),
596
597
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
597
598
}
598
599
} else if constexpr (M == 8 && N == 32 && K == 16 ) {
599
600
if constexpr (std::is_same<T2, int32_t >::value) {
600
- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
601
- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
602
- auto ptrC = reinterpret_cast <int32_t const *>(&C.wi_marray );
601
+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
602
+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
603
+ auto ptrC = reinterpret_cast <const int32_t *>(&C.wi_marray );
603
604
auto ptrD = reinterpret_cast <int32_t *>(&D.wi_marray );
604
605
if constexpr (std::is_same<T1, int8_t >::value) {
605
606
__imma_m8n32k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
@@ -609,34 +610,34 @@ struct joint_matrix_mad_impl<
609
610
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
610
611
}
611
612
} else if constexpr (std::is_same<T1, half>::value) {
612
- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
613
- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
613
+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
614
+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
614
615
if constexpr (std::is_same<T2, float >::value) {
615
616
__hmma_m8n32k16_mma_f32f32 (
616
617
reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
617
- reinterpret_cast <float const *>(&C.wi_marray ),
618
+ reinterpret_cast <const float *>(&C.wi_marray ),
618
619
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
619
620
} else if constexpr (std::is_same<T2, half>::value) {
620
621
__hmma_m8n32k16_mma_f16f16 (
621
622
reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
622
- reinterpret_cast <int32_t const *>(&C.wi_marray ),
623
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
623
624
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
624
625
}
625
626
} else if constexpr (std::is_same<T1, uint16_t >::value ||
626
627
std::is_same<T1, sycl::ext::oneapi::experimental::
627
628
bfloat16>::value) {
628
629
__mma_bf16_m8n32k16_mma_f32 (
629
630
reinterpret_cast <float *>(&D.wi_marray ),
630
- reinterpret_cast <int32_t const *>(&A.wi_marray ),
631
- reinterpret_cast <int32_t const *>(&B.wi_marray ),
632
- reinterpret_cast <float const *>(&C.wi_marray ),
631
+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
632
+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
633
+ reinterpret_cast <const float *>(&C.wi_marray ),
633
634
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
634
635
}
635
636
} else if constexpr (M == 32 && N == 8 && K == 16 ) {
636
637
if constexpr (std::is_same<T2, int32_t >::value) {
637
- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
638
- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
639
- auto ptrC = reinterpret_cast <int32_t const *>(&C.wi_marray );
638
+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
639
+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
640
+ auto ptrC = reinterpret_cast <const int32_t *>(&C.wi_marray );
640
641
auto ptrD = reinterpret_cast <int32_t *>(&D.wi_marray );
641
642
if constexpr (std::is_same<T1, int8_t >::value) {
642
643
__imma_m32n8k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
@@ -650,22 +651,22 @@ struct joint_matrix_mad_impl<
650
651
bfloat16>::value) {
651
652
__mma_bf16_m32n8k16_mma_f32 (
652
653
reinterpret_cast <float *>(&D.wi_marray ),
653
- reinterpret_cast <int32_t const *>(&A.wi_marray ),
654
- reinterpret_cast <int32_t const *>(&B.wi_marray ),
655
- reinterpret_cast <float const *>(&C.wi_marray ),
654
+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
655
+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
656
+ reinterpret_cast <const float *>(&C.wi_marray ),
656
657
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
657
658
} else if constexpr (std::is_same<T1, half>::value) {
658
- auto ptrA = reinterpret_cast <int32_t const *>(&A.wi_marray );
659
- auto ptrB = reinterpret_cast <int32_t const *>(&B.wi_marray );
659
+ auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
660
+ auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
660
661
if constexpr (std::is_same<T2, float >::value) {
661
662
__hmma_m32n8k16_mma_f32f32 (
662
663
reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
663
- reinterpret_cast <float const *>(&C.wi_marray ),
664
+ reinterpret_cast <const float *>(&C.wi_marray ),
664
665
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
665
666
} else if constexpr (std::is_same<T2, half>::value) {
666
667
__hmma_m32n8k16_mma_f16f16 (
667
668
reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
668
- reinterpret_cast <int32_t const *>(&C.wi_marray ),
669
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
669
670
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
670
671
}
671
672
}
@@ -677,9 +678,9 @@ struct joint_matrix_mad_impl<
677
678
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
678
679
} else if constexpr (std::is_same<T1, double >::value) {
679
680
__dmma_m8n8k4_mma_f64 (reinterpret_cast <double *>(&D.wi_marray ),
680
- reinterpret_cast <double const *>(&A.wi_marray ),
681
- reinterpret_cast <double const *>(&B.wi_marray ),
682
- reinterpret_cast <double const *>(&C.wi_marray ),
681
+ reinterpret_cast <const double *>(&A.wi_marray ),
682
+ reinterpret_cast <const double *>(&B.wi_marray ),
683
+ reinterpret_cast <const double *>(&C.wi_marray ),
683
684
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
684
685
}
685
686
return D;
@@ -692,13 +693,14 @@ struct joint_matrix_mad_impl<
692
693
namespace experimental {
693
694
namespace matrix {
694
695
695
- template <typename Group, typename S, typename T, matrix_use Use,
696
- size_t NumRows, size_t NumCols, matrix_layout Layout,
697
- access::address_space Space,
698
- std::enable_if_t <std::is_same<S, T>::value ||
699
- (std::is_same<S, precision::tf32>::value &&
700
- std::is_same<T, float >::value),
701
- bool > = true >
696
+ template <
697
+ typename Group, typename S, typename T, matrix_use Use, size_t NumRows,
698
+ size_t NumCols, matrix_layout Layout, access::address_space Space,
699
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value ||
700
+ (std::is_same<S, precision::tf32>::value &&
701
+
702
+ std::is_same<std::remove_const_t <T>, float >::value),
703
+ bool > = true >
702
704
void joint_matrix_load (
703
705
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
704
706
multi_ptr<T, Space> src, size_t stride) {
0 commit comments