@@ -275,12 +275,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
275
275
wvSplitK_hf_sml_(const int K, const int M, const scalar_t * B,
276
276
const scalar_t * __restrict__ A, scalar_t * C,
277
277
const int _WvPrGrp, const int CuCount) {
278
-
279
- #if defined(__HIP__MI300__)
278
+ #if defined(__HIP__MI300__)
280
279
constexpr bool use_mfma = (std::is_same_v<scalar_t , __hip_bfloat16>);
281
- #else
280
+ #else
282
281
constexpr bool use_mfma = false ;
283
- #endif
282
+ #endif
284
283
285
284
using scalar8 =
286
285
__attribute__ ((__vector_size__ ((A_CHUNK / 2 ) * sizeof (float )))) float ;
@@ -389,7 +388,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
389
388
if (k_ >= K) break ;
390
389
391
390
const scalar_t * B_ = &B[(m + 0 ) * K + k_];
392
- for (int y= 0 ; y< YTILE; y++)
391
+ for (int y = 0 ; y < YTILE; y++)
393
392
bigB[y][k2].h8 = (loadnt ((scalar8*)(&B_[y * K])));
394
393
}
395
394
@@ -418,16 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
418
417
#pragma unroll
419
418
for (uint32_t n = 0 ; n < N; n++) {
420
419
#pragma unroll
421
- for (int y= 0 ; y< YTILE; y++) {
420
+ for (int y = 0 ; y < YTILE; y++) {
422
421
if constexpr (!use_mfma)
423
422
#pragma unroll
424
- for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
423
+ for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
425
424
DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
426
- }
425
+ }
427
426
else
428
427
#pragma unroll
429
428
for (uint32_t b = 0 ; b < A_CHUNK / 4 ; b++)
430
- sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
429
+ sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
430
+ bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
431
431
}
432
432
}
433
433
}
@@ -440,23 +440,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
440
440
for (int n = 0 ; n < N; n++) {
441
441
for (int y = 0 ; y < YTILE; y++) {
442
442
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
443
- : " =v" (sum[n][y])
444
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
443
+ : " =v" (sum[n][y])
444
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
445
445
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
446
- : " =v" (sum[n][y])
447
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
446
+ : " =v" (sum[n][y])
447
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
448
448
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
449
- : " =v" (sum[n][y])
450
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
449
+ : " =v" (sum[n][y])
450
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
451
451
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
452
- : " =v" (sum[n][y])
453
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
452
+ : " =v" (sum[n][y])
453
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
454
454
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
455
- : " =v" (sum[n][y])
456
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
455
+ : " =v" (sum[n][y])
456
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
457
457
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
458
- : " =v" (sum[n][y])
459
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
458
+ : " =v" (sum[n][y])
459
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
460
460
}
461
461
}
462
462
@@ -473,9 +473,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
473
473
for (int n = 0 ; n < N; n++) {
474
474
#pragma unroll
475
475
for (int y = 0 ; y < YTILE; y++) {
476
- // float accm1 = 0;
477
- // for (int i=0; i<64; i++)
478
- // accm1 += __shfl(sum4[n][y][i%4], i);
476
+ // float accm1 = 0;
477
+ // for (int i=0; i<64; i++)
478
+ // accm1 += __shfl(sum4[n][y][i%4], i);
479
479
float accm = sum4[n][y][0 ];
480
480
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
481
481
: " =v" (accm)
@@ -535,11 +535,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
535
535
wvSplitK_hf_(const int K, const int M, const scalar_t * B,
536
536
const scalar_t * __restrict__ A, scalar_t * C,
537
537
const int _WvPrGrp, const int CuCount) {
538
- #if defined(__HIP__MI300__)
538
+ #if defined(__HIP__MI300__)
539
539
constexpr bool use_mfma = (std::is_same_v<scalar_t , __hip_bfloat16>);
540
- #else
540
+ #else
541
541
constexpr bool use_mfma = false ;
542
- #endif
542
+ #endif
543
543
544
544
using scalar8 =
545
545
__attribute__ ((__vector_size__ ((A_CHUNK / 2 ) * sizeof (float )))) float ;
@@ -672,7 +672,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
672
672
if (k_ >= K) break ;
673
673
674
674
const scalar_t * B_ = &B[(m + 0 ) * K + k_];
675
- for (int b= 0 ; b< YTILE; b++)
675
+ for (int b = 0 ; b < YTILE; b++)
676
676
bigB[b][k2].h8 = (loadnt ((scalar8*)(&B_[b * K])));
677
677
}
678
678
@@ -704,16 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
704
704
// Do the matrix multiplication of activation and weight matrix
705
705
// - Remember the accumulation is happening for K-split of 64!
706
706
#pragma unroll
707
- for (int y= 0 ; y< YTILE; y++) {
707
+ for (int y = 0 ; y < YTILE; y++) {
708
708
if constexpr (!use_mfma)
709
709
#pragma unroll
710
710
for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
711
711
DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
712
- }
712
+ }
713
713
else
714
714
#pragma unroll
715
715
for (uint32_t b = 0 ; b < A_CHUNK / 4 ; b++)
716
- sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
716
+ sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
717
+ bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
717
718
}
718
719
}
719
720
}
@@ -726,23 +727,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
726
727
for (int n = 0 ; n < N; n++) {
727
728
for (int y = 0 ; y < YTILE; y++) {
728
729
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
729
- : " =v" (sum[n][y])
730
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
730
+ : " =v" (sum[n][y])
731
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
731
732
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
732
- : " =v" (sum[n][y])
733
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
733
+ : " =v" (sum[n][y])
734
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
734
735
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
735
- : " =v" (sum[n][y])
736
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
736
+ : " =v" (sum[n][y])
737
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
737
738
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
738
- : " =v" (sum[n][y])
739
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
739
+ : " =v" (sum[n][y])
740
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
740
741
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
741
- : " =v" (sum[n][y])
742
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
742
+ : " =v" (sum[n][y])
743
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
743
744
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
744
- : " =v" (sum[n][y])
745
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
745
+ : " =v" (sum[n][y])
746
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
746
747
}
747
748
}
748
749
@@ -759,9 +760,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
759
760
for (int n = 0 ; n < N; n++) {
760
761
#pragma unroll
761
762
for (int y = 0 ; y < YTILE; y++) {
762
- // float accm1 = 0;
763
- // for (int i=0; i<64; i++)
764
- // accm1 += __shfl(sum4[n][y][i%4], i);
763
+ // float accm1 = 0;
764
+ // for (int i=0; i<64; i++)
765
+ // accm1 += __shfl(sum4[n][y][i%4], i);
765
766
766
767
float accm = sum4[n][y][0 ];
767
768
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
@@ -834,11 +835,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
834
835
wvSplitK_hf_big_(const int K, const int M, const scalar_t * B,
835
836
const scalar_t * __restrict__ A, scalar_t * C,
836
837
const int _WvPrGrp, const int CuCount) {
837
- #if defined(__HIP__MI300__)
838
+ #if defined(__HIP__MI300__)
838
839
constexpr bool use_mfma = (std::is_same_v<scalar_t , __hip_bfloat16>);
839
- #else
840
+ #else
840
841
constexpr bool use_mfma = false ;
841
- #endif
842
+ #endif
842
843
843
844
using scalar8 =
844
845
__attribute__ ((__vector_size__ ((A_CHUNK / 2 ) * sizeof (float )))) float ;
@@ -961,7 +962,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
961
962
if constexpr (!use_mfma)
962
963
sum[n][i] = 0 ;
963
964
else
964
- sum4[n][i] = {0 ,0 , 0 , 0 };
965
+ sum4[n][i] = {0 , 0 , 0 , 0 };
965
966
966
967
bigType bigA[N][UNRL];
967
968
bigType bigB[YTILE][UNRL];
@@ -1010,7 +1011,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1010
1011
if (k_ >= K) break ;
1011
1012
1012
1013
const scalar_t * B_ = &B[(m + 0 ) * K + k_];
1013
- for (int b= 0 ; b< YTILE; b++)
1014
+ for (int b = 0 ; b < YTILE; b++)
1014
1015
bigB[b][k2].h8 = (loadnt ((scalar8*)(&B_[b * K])));
1015
1016
}
1016
1017
@@ -1046,16 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1046
1047
// Do the matrix multiplication of activation and weight matrix
1047
1048
// - Remember the accumulation is happening for K-split of 64!
1048
1049
#pragma unroll
1049
- for (int y= 0 ; y< YTILE; y++) {
1050
+ for (int y = 0 ; y < YTILE; y++) {
1050
1051
if constexpr (!use_mfma)
1051
1052
#pragma unroll
1052
1053
for (uint32_t b = 0 ; b < A_CHUNK / 2 ; b++) {
1053
1054
DOT2C (sum[n][y], bigA[n][k2].f [b], bigB[y][k2].f [b])
1054
- }
1055
- else
1055
+ }
1056
+ else
1056
1057
#pragma unroll
1057
1058
for (uint32_t b = 0 ; b < A_CHUNK / 4 ; b++)
1058
- sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
1059
+ sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
1060
+ bigA[n][k2].h4 [b], bigB[y][k2].h4 [b], sum4[n][y], 0 , 0 , 0 );
1059
1061
}
1060
1062
}
1061
1063
}
@@ -1076,23 +1078,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1076
1078
for (int n = 0 ; n < N; n++) {
1077
1079
for (int y = 0 ; y < YTILE; y++) {
1078
1080
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
1079
- : " =v" (sum[n][y])
1080
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1081
+ : " =v" (sum[n][y])
1082
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1081
1083
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
1082
- : " =v" (sum[n][y])
1083
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1084
+ : " =v" (sum[n][y])
1085
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1084
1086
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
1085
- : " =v" (sum[n][y])
1086
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1087
+ : " =v" (sum[n][y])
1088
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1087
1089
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
1088
- : " =v" (sum[n][y])
1089
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1090
+ : " =v" (sum[n][y])
1091
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1090
1092
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
1091
- : " =v" (sum[n][y])
1092
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1093
+ : " =v" (sum[n][y])
1094
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1093
1095
asm (" s_nop 0\n\t v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
1094
- : " =v" (sum[n][y])
1095
- : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1096
+ : " =v" (sum[n][y])
1097
+ : " 0" (sum[n][y]), " v" (sum[n][y]), " v" (sum[n][y]));
1096
1098
}
1097
1099
}
1098
1100
@@ -1148,7 +1150,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1148
1150
}
1149
1151
}
1150
1152
1151
-
1152
1153
m += CuCount * _WvPrGrp * YTILE;
1153
1154
kBase = 0 ;
1154
1155
0 commit comments