Skip to content

Commit 87dbab7

Browse files
committed
fix MI250
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
1 parent b2e2d43 commit 87dbab7

File tree

1 file changed

+40
-45
lines changed

1 file changed

+40
-45
lines changed

csrc/rocm/skinny_gemms.cu

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
275275
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
276276
const scalar_t* __restrict__ A, scalar_t* C,
277277
const int _WvPrGrp, const int CuCount) {
278+
279+
#if defined(__HIP__MI300__)
280+
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
281+
#else
282+
constexpr bool use_mfma = false;
283+
#endif
284+
278285
using scalar8 =
279286
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
280287
using half4 =
@@ -348,10 +355,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
348355
//----------------------------------------------------
349356
for (int i = 0; i < YTILE; i++)
350357
for (int n = 0; n < N; n++)
351-
if constexpr (std::is_same_v<scalar_t, half>)
358+
if constexpr (!use_mfma)
352359
sum[n][i] = 0;
353360
else
354-
sum4[n][i] = {0,0,0,0};
361+
sum4[n][i] = {0, 0, 0, 0};
355362

356363
bigType bigA[N][UNRL];
357364
bigType bigB[YTILE][UNRL];
@@ -412,22 +419,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
412419
for (uint32_t n = 0; n < N; n++) {
413420
#pragma unroll
414421
for (int y=0; y<YTILE; y++) {
415-
if constexpr (std::is_same_v<scalar_t, half>)
422+
if constexpr (!use_mfma)
416423
#pragma unroll
417424
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
418425
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
419-
}
420-
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
421-
#if defined(__HIP__MI300__)
426+
}
427+
else
422428
#pragma unroll
423429
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
424-
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
425-
#else
426-
#pragma unroll
427-
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
428-
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
429-
}
430-
#endif
430+
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
431431
}
432432
}
433433
}
@@ -436,7 +436,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
436436
//----------------------------------------------------
437437
// Final reduction step using shuffle
438438
//----------------------------------------------------
439-
if constexpr (std::is_same_v<scalar_t, half>) {
439+
if constexpr (!use_mfma) {
440440
for (int n = 0; n < N; n++) {
441441
for (int y = 0; y < YTILE; y++) {
442442
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
459459
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
460460
}
461461
}
462-
462+
463463
if (threadIdx.x == 63) {
464464
for (int n = 0; n < N; n++) {
465465
for (int i = 0; i < YTILE; i++) {
@@ -468,8 +468,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
468468
}
469469
}
470470
}
471-
}
472-
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
471+
} else {
473472
#pragma unroll
474473
for (int n = 0; n < N; n++) {
475474
#pragma unroll
@@ -536,6 +535,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
536535
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
537536
const scalar_t* __restrict__ A, scalar_t* C,
538537
const int _WvPrGrp, const int CuCount) {
538+
#if defined(__HIP__MI300__)
539+
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
540+
#else
541+
constexpr bool use_mfma = false;
542+
#endif
543+
539544
using scalar8 =
540545
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
541546
using half4 =
@@ -634,10 +639,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
634639
//----------------------------------------------------
635640
for (int i = 0; i < YTILE; i++)
636641
for (int n = 0; n < N; n++)
637-
if constexpr (std::is_same_v<scalar_t, half>)
642+
if constexpr (!use_mfma)
638643
sum[n][i] = 0;
639644
else
640-
sum4[n][i] = {0,0,0,0};
645+
sum4[n][i] = {0, 0, 0, 0};
641646

642647
bigType bigA[N][UNRL];
643648
bigType bigB[YTILE][UNRL];
@@ -700,22 +705,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
700705
// - Remember the accumulation is happening for K-split of 64!
701706
#pragma unroll
702707
for (int y=0; y<YTILE; y++) {
703-
if constexpr (std::is_same_v<scalar_t, half>)
708+
if constexpr (!use_mfma)
704709
#pragma unroll
705-
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
710+
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
706711
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
707712
}
708-
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
709-
#if defined(__HIP__MI300__)
713+
else
710714
#pragma unroll
711715
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
712716
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
713-
#else
714-
#pragma unroll
715-
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
716-
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
717-
}
718-
#endif
719717
}
720718
}
721719
}
@@ -724,7 +722,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
724722
//----------------------------------------------------
725723
// Final reduction step using shuffle
726724
//----------------------------------------------------
727-
if constexpr (std::is_same_v<scalar_t, half>) {
725+
if constexpr (!use_mfma) {
728726
for (int n = 0; n < N; n++) {
729727
for (int y = 0; y < YTILE; y++) {
730728
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -756,8 +754,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
756754
}
757755
}
758756
}
759-
}
760-
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
757+
} else {
761758
#pragma unroll
762759
for (int n = 0; n < N; n++) {
763760
#pragma unroll
@@ -837,6 +834,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
837834
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
838835
const scalar_t* __restrict__ A, scalar_t* C,
839836
const int _WvPrGrp, const int CuCount) {
837+
#if defined(__HIP__MI300__)
838+
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
839+
#else
840+
constexpr bool use_mfma = false;
841+
#endif
842+
840843
using scalar8 =
841844
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
842845
using half4 =
@@ -955,7 +958,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
955958
//----------------------------------------------------
956959
for (int i = 0; i < YTILE; i++)
957960
for (int n = 0; n < N; n++)
958-
if constexpr (std::is_same_v<scalar_t, half>)
961+
if constexpr (!use_mfma)
959962
sum[n][i] = 0;
960963
else
961964
sum4[n][i] = {0,0,0,0};
@@ -1044,22 +1047,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
10441047
// - Remember the accumulation is happening for K-split of 64!
10451048
#pragma unroll
10461049
for (int y=0; y<YTILE; y++) {
1047-
if constexpr (std::is_same_v<scalar_t, half>)
1050+
if constexpr (!use_mfma)
10481051
#pragma unroll
10491052
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
10501053
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
10511054
}
1052-
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
1053-
#if defined(__HIP__MI300__)
1055+
else
10541056
#pragma unroll
10551057
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
10561058
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
1057-
#else
1058-
#pragma unroll
1059-
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
1060-
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
1061-
}
1062-
#endif
10631059
}
10641060
}
10651061
}
@@ -1076,7 +1072,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
10761072
//----------------------------------------------------
10771073
// Final reduction step using shuffle
10781074
//----------------------------------------------------
1079-
if constexpr (std::is_same_v<scalar_t, half>) {
1075+
if constexpr (!use_mfma) {
10801076
for (int n = 0; n < N; n++) {
10811077
for (int y = 0; y < YTILE; y++) {
10821078
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -1108,8 +1104,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
11081104
}
11091105
}
11101106
}
1111-
}
1112-
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
1107+
} else {
11131108
#pragma unroll
11141109
for (int n = 0; n < N; n++) {
11151110
#pragma unroll

0 commit comments

Comments
 (0)