@@ -275,6 +275,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
275
275
wvSplitK_hf_sml_(const int K, const int M, const scalar_t * B,
276
276
const scalar_t * __restrict__ A, scalar_t * C,
277
277
const int _WvPrGrp, const int CuCount) {
278
+
279
+ #if defined(__HIP__MI300__)
280
+ constexpr bool use_mfma = (std::is_same_v<scalar_t , __hip_bfloat16>);
281
+ #else
282
+ constexpr bool use_mfma = false ;
283
+ #endif
284
+
278
285
using scalar8 =
279
286
__attribute__ ((__vector_size__ ((A_CHUNK / 2 ) * sizeof (float )))) float ;
280
287
using half4 =
@@ -348,10 +355,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
348
355
// ----------------------------------------------------
349
356
for (int i = 0 ; i < YTILE; i++)
350
357
for (int n = 0 ; n < N; n++)
351
- if constexpr (std::is_same_v< scalar_t , half> )
358
+ if constexpr (!use_mfma )
352
359
sum[n][i] = 0 ;
353
360
else
354
- sum4[n][i] = {0 ,0 , 0 , 0 };
361
+ sum4[n][i] = {0 , 0 , 0 , 0 };
355
362
356
363
bigType bigA[N][UNRL];
357
364
bigType bigB[YTILE][UNRL];
@@ -412,22 +419,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
412
419
for (uint32_t n = 0 ; n < N; n++) {
413
420
#pragma unroll
414
421
for (int y=0 ; y<YTILE; y++) {
415
- if constexpr (std::is_same_v< scalar_t , half> )
422
+ if constexpr (!use_mfma )
416
423
#pragma unroll
417
424
for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
418
425
DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
419
- }
420
- if constexpr (std::is_same_v<scalar_t , __hip_bfloat16>)
421
- #if defined(__HIP__MI300__)
426
+ }
427
+ else
422
428
#pragma unroll
423
429
for (uint32_t b = 0 ; b < A_CHUNK / 4 ; b++)
424
- sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
425
- #else
426
- #pragma unroll
427
- for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
428
- DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
429
- }
430
- #endif
430
+ sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
431
431
}
432
432
}
433
433
}
@@ -436,7 +436,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
436
436
// ----------------------------------------------------
437
437
// Final reduction step using shuffle
438
438
// ----------------------------------------------------
439
- if constexpr (std::is_same_v< scalar_t , half> ) {
439
+ if constexpr (!use_mfma ) {
440
440
for (int n = 0 ; n < N; n++) {
441
441
for (int y = 0 ; y < YTILE; y++) {
442
442
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
459
459
: " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
460
460
}
461
461
}
462
-
462
+
463
463
if (threadIdx .x == 63 ) {
464
464
for (int n = 0 ; n < N; n++) {
465
465
for (int i = 0 ; i < YTILE; i++) {
@@ -468,8 +468,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
468
468
}
469
469
}
470
470
}
471
- }
472
- if constexpr (std::is_same_v<scalar_t , __hip_bfloat16>) {
471
+ } else {
473
472
#pragma unroll
474
473
for (int n = 0 ; n < N; n++) {
475
474
#pragma unroll
@@ -536,6 +535,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
536
535
wvSplitK_hf_(const int K, const int M, const scalar_t * B,
537
536
const scalar_t * __restrict__ A, scalar_t * C,
538
537
const int _WvPrGrp, const int CuCount) {
538
+ #if defined(__HIP__MI300__)
539
+ constexpr bool use_mfma = (std::is_same_v<scalar_t , __hip_bfloat16>);
540
+ #else
541
+ constexpr bool use_mfma = false ;
542
+ #endif
543
+
539
544
using scalar8 =
540
545
__attribute__ ((__vector_size__ ((A_CHUNK / 2 ) * sizeof (float )))) float ;
541
546
using half4 =
@@ -634,10 +639,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
634
639
// ----------------------------------------------------
635
640
for (int i = 0 ; i < YTILE; i++)
636
641
for (int n = 0 ; n < N; n++)
637
- if constexpr (std::is_same_v< scalar_t , half> )
642
+ if constexpr (!use_mfma )
638
643
sum[n][i] = 0 ;
639
644
else
640
- sum4[n][i] = {0 ,0 , 0 , 0 };
645
+ sum4[n][i] = {0 , 0 , 0 , 0 };
641
646
642
647
bigType bigA[N][UNRL];
643
648
bigType bigB[YTILE][UNRL];
@@ -700,22 +705,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
700
705
// - Remember the accumulation is happening for K-split of 64!
701
706
#pragma unroll
702
707
for (int y=0 ; y<YTILE; y++) {
703
- if constexpr (std::is_same_v< scalar_t , half> )
708
+ if constexpr (!use_mfma )
704
709
#pragma unroll
705
- for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
710
+ for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
706
711
DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
707
712
}
708
- if constexpr (std::is_same_v<scalar_t , __hip_bfloat16>)
709
- #if defined(__HIP__MI300__)
713
+ else
710
714
#pragma unroll
711
715
for (uint32_t b = 0 ; b < A_CHUNK / 4 ; b++)
712
716
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
713
- #else
714
- #pragma unroll
715
- for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
716
- DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
717
- }
718
- #endif
719
717
}
720
718
}
721
719
}
@@ -724,7 +722,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
724
722
// ----------------------------------------------------
725
723
// Final reduction step using shuffle
726
724
// ----------------------------------------------------
727
- if constexpr (std::is_same_v< scalar_t , half> ) {
725
+ if constexpr (!use_mfma ) {
728
726
for (int n = 0 ; n < N; n++) {
729
727
for (int y = 0 ; y < YTILE; y++) {
730
728
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -756,8 +754,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
756
754
}
757
755
}
758
756
}
759
- }
760
- if constexpr (std::is_same_v<scalar_t , __hip_bfloat16>) {
757
+ } else {
761
758
#pragma unroll
762
759
for (int n = 0 ; n < N; n++) {
763
760
#pragma unroll
@@ -837,6 +834,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
837
834
wvSplitK_hf_big_(const int K, const int M, const scalar_t * B,
838
835
const scalar_t * __restrict__ A, scalar_t * C,
839
836
const int _WvPrGrp, const int CuCount) {
837
+ #if defined(__HIP__MI300__)
838
+ constexpr bool use_mfma = (std::is_same_v<scalar_t , __hip_bfloat16>);
839
+ #else
840
+ constexpr bool use_mfma = false ;
841
+ #endif
842
+
840
843
using scalar8 =
841
844
__attribute__ ((__vector_size__ ((A_CHUNK / 2 ) * sizeof (float )))) float ;
842
845
using half4 =
@@ -955,7 +958,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
955
958
// ----------------------------------------------------
956
959
for (int i = 0 ; i < YTILE; i++)
957
960
for (int n = 0 ; n < N; n++)
958
- if constexpr (std::is_same_v< scalar_t , half> )
961
+ if constexpr (!use_mfma )
959
962
sum[n][i] = 0 ;
960
963
else
961
964
sum4[n][i] = {0 ,0 ,0 ,0 };
@@ -1044,22 +1047,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1044
1047
// - Remember the accumulation is happening for K-split of 64!
1045
1048
#pragma unroll
1046
1049
for (int y=0 ; y<YTILE; y++) {
1047
- if constexpr (std::is_same_v< scalar_t , half> )
1050
+ if constexpr (!use_mfma )
1048
1051
#pragma unroll
1049
1052
for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
1050
1053
DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
1051
1054
}
1052
- if constexpr (std::is_same_v<scalar_t , __hip_bfloat16>)
1053
- #if defined(__HIP__MI300__)
1055
+ else
1054
1056
#pragma unroll
1055
1057
for (uint32_t b = 0 ; b < A_CHUNK / 4 ; b++)
1056
1058
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
1057
- #else
1058
- #pragma unroll
1059
- for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
1060
- DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
1061
- }
1062
- #endif
1063
1059
}
1064
1060
}
1065
1061
}
@@ -1076,7 +1072,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1076
1072
// ----------------------------------------------------
1077
1073
// Final reduction step using shuffle
1078
1074
// ----------------------------------------------------
1079
- if constexpr (std::is_same_v< scalar_t , half> ) {
1075
+ if constexpr (!use_mfma ) {
1080
1076
for (int n = 0 ; n < N; n++) {
1081
1077
for (int y = 0 ; y < YTILE; y++) {
1082
1078
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -1108,8 +1104,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1108
1104
}
1109
1105
}
1110
1106
}
1111
- }
1112
- if constexpr (std::is_same_v<scalar_t , __hip_bfloat16>) {
1107
+ } else {
1113
1108
#pragma unroll
1114
1109
for (int n = 0 ; n < N; n++) {
1115
1110
#pragma unroll
0 commit comments