From b2e2d432b3ee17a4a01f32aba5193b2460c4981b Mon Sep 17 00:00:00 2001
From: Hashem Hashemi
Date: Wed, 23 Apr 2025 04:58:58 +0000
Subject: [PATCH 1/6] MFMA optimization of the wvSplitK kernel for bf16 skinny
 GEMMs

Signed-off-by: Hashem Hashemi

---
 csrc/rocm/skinny_gemms.cu | 443 +++++++++++++++++++++++---------------
 1 file changed, 267 insertions(+), 176 deletions(-)

diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 29dbbe8e35e8..254900643c1d 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -277,11 +277,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     const int _WvPrGrp, const int CuCount) {
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -318,6 +321,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
 
   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];
 
   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -343,7 +347,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     // are being worked on by each wave.
     //----------------------------------------------------
     for (int i = 0; i < YTILE; i++)
-      for (int n = 0; n < N; n++) sum[n][i] = 0;
+      for (int n = 0; n < N; n++)
+        if constexpr (std::is_same_v<scalar_t, half>)
+          sum[n][i] = 0;
+        else
+          sum4[n][i] = {0,0,0,0};
 
     bigType bigA[N][UNRL];
     bigType bigB[YTILE][UNRL];
@@ -374,24 +382,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
         if (k_ >= K) break;
 
         const scalar_t* B_ = &B[(m + 0) * K + k_];
-        bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-        //----------------------------------------------------
-        // The following code with YTILE > 1 has to be deleted
-        //----------------------------------------------------
-        if constexpr (YTILE >= 2)
-          bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-        if constexpr (YTILE >= 3)
-          bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-        if constexpr (YTILE >= 4)
-          bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-        if constexpr (YTILE >= 5)
-          bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-        if constexpr (YTILE >= 6)
-          bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-        if constexpr (YTILE >= 7)
-          bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-        if constexpr (YTILE >= 8)
-          bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+        for (int y=0; y<YTILE; y++)
+          bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
       }
         // Do the matrix multiplication of activation and weight matrix
         // - Remember the accumulation is happening for K-split of 64!
 #pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y=0; y<YTILE; y++) {
+          if constexpr (std::is_same_v<scalar_t, half>)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+          if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
+#if defined(__HIP__MI300__)
+  #pragma unroll
+          for (uint32_t b = 0; 
b < A_CHUNK / 4; b++)
+            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
+#else
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+#endif
+        }
       }
     }
   }
@@ -453,37 +436,85 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
    // Final reduction step using shuffle
     //----------------------------------------------------
-    for (int n = 0; n < N; n++) {
-      for (int y = 0; y < YTILE; y++) {
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+    if constexpr (std::is_same_v<scalar_t, half>) {
+      for (int n = 0; n < N; n++) {
+        for (int y = 0; y < YTILE; y++) {
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
             : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
             : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        }
+      }
+
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+          }
+        }
       }
     }
-
-    if (threadIdx.x == 63) {
+    if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+  #pragma unroll
       for (int n = 0; n < N; n++) {
-        for (int i = 0; i < YTILE; i++) {
-          // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
-          C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+  #pragma unroll
+        for (int y = 0; y < YTILE; y++) {
+          //float accm1 = 0;
+          //for (int i=0; i<64; i++)
+          //  accm1 += __shfl(sum4[n][y][i%4], i);
+
+          float accm = sum4[n][y][0];
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+
+          sum4[n][y][0] = accm;
+        }
+      }
+      if (threadIdx.x == 63) {
+        for (int n = 0; 
n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
+          }
+        }
+      }
+    }
-
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
@@ -507,11 +538,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     const int _WvPrGrp, const int CuCount) {
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -573,6 +607,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   if (threadIdx.y >= _WvPrGrp) return;
 
   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];
 
   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -598,7 +633,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     // are being worked on by each wave.
     //----------------------------------------------------
     for (int i = 0; i < YTILE; i++)
-      for (int n = 0; n < N; n++) sum[n][i] = 0;
+      for (int n = 0; n < N; n++)
+        if constexpr (std::is_same_v<scalar_t, half>)
+          sum[n][i] = 0;
+        else
+          sum4[n][i] = {0,0,0,0};
 
     bigType bigA[N][UNRL];
     bigType bigB[YTILE][UNRL];
@@ -628,24 +667,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
        if (k_ >= K) break;
 
       const scalar_t* B_ = &B[(m + 0) * K + k_];
-      bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-      //----------------------------------------------------
-      // The following code with YTILE > 1 has to be deleted
-      //----------------------------------------------------
-      if constexpr (YTILE >= 2)
-        bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-      if constexpr (YTILE >= 3)
-        bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-      if constexpr (YTILE >= 4)
-        bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-      if constexpr (YTILE >= 5)
-        bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-      if constexpr (YTILE >= 6)
-        bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-      if constexpr (YTILE >= 7)
-        bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-      if constexpr (YTILE >= 8)
-        bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+      for (int b=0; b<YTILE; b++)
+        bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
     }
         // Do the matrix multiplication of activation and weight matrix
         // - Remember the accumulation is happening for K-split of 64!
 #pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y=0; y<YTILE; y++) {
+          if constexpr (std::is_same_v<scalar_t, half>)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+          if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
+#if defined(__HIP__MI300__)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
+#else
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+#endif
+        }
       }
     }
   }
@@ -710,34 +724,83 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
     // Final reduction step using shuffle
     //----------------------------------------------------
-    for (int n = 0; n < N; n++) {
-      for (int y = 0; y < YTILE; y++) {
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+    if constexpr (std::is_same_v<scalar_t, half>) {
+      for (int n = 0; n < N; n++) {
+        for (int y = 0; y < YTILE; y++) {
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
            : "=v"(sum[n][y])
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        }
       }
-    }
 
-    if (threadIdx.x == 63) {
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            if (commitColumn[i])
+              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+          }
+        }
+      }
+    }
+    if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+  #pragma unroll
       for (int n = 0; n < N; n++) {
-        for (int i = 0; i < YTILE; i++) {
-          if (commitColumn[i])
-            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+  #pragma unroll
+        for (int y = 0; y < YTILE; y++) {
+          //float accm1 = 0;
+          //for (int i=0; i<64; i++)
+          //  accm1 += __shfl(sum4[n][y][i%4], i);
+
+          float accm = sum4[n][y][0];
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
+              : "=v"(accm)
+              : "0"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+              : "=v"(accm)
+              : "0"(accm), "v"(accm), "v"(accm));
+
+          sum4[n][y][0] = accm;
+        }
+      }
+      if (threadIdx.x == 63) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
+            C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
+          }
         }
       }
     }
@@ -776,12 +839,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     const int _WvPrGrp, const int CuCount) {
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
-
+  using half4 =
+      
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -857,6 +922,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   kFit = min(kFit, K);
 
   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];
 
   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -888,7 +954,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
    // are being worked on by each wave.
     //----------------------------------------------------
     for (int i = 0; i < YTILE; i++)
-      for (int n = 0; n < N; n++) sum[n][i] = 0;
+      for (int n = 0; n < N; n++)
+        if constexpr (std::is_same_v<scalar_t, half>)
+          sum[n][i] = 0;
+        else
+          sum4[n][i] = {0,0,0,0};
 
     bigType bigA[N][UNRL];
     bigType bigB[YTILE][UNRL];
@@ -937,24 +1007,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
        if (k_ >= K) break;
 
       const scalar_t* B_ = &B[(m + 0) * K + k_];
-      bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-      //----------------------------------------------------
-      // The following code with YTILE > 1 has to be deleted
-      //----------------------------------------------------
-      if constexpr (YTILE >= 2)
-        bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-      if constexpr (YTILE >= 3)
-        bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-      if constexpr (YTILE >= 4)
-        bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-      if constexpr (YTILE >= 5)
-        bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-      if constexpr (YTILE >= 6)
-        bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-      if constexpr (YTILE >= 7)
-        bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-      if constexpr (YTILE >= 8)
-        bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+      for (int b=0; b<YTILE; b++)
+        bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
     }
         // Do the matrix multiplication of activation and weight matrix
         // - Remember the accumulation is happening for K-split of 64!
 #pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y=0; y<YTILE; y++) {
+          if constexpr (std::is_same_v<scalar_t, half>)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+          if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
+#if defined(__HIP__MI300__)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
+#else
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+#endif
+        }
       }
     }
   }
@@ -1031,38 +1076,84 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
    // Final reduction step using shuffle
     //----------------------------------------------------
-    for (int n = 0; n < N; n++) {
-      for (int y = 0; y < YTILE; y++) {
-        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+    if constexpr (std::is_same_v<scalar_t, half>) {
+      for (int n = 0; n < N; n++) {
+        for (int y = 0; y < YTILE; y++) {
+          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
            : 
"=v"(sum[n][y]) : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " : "=v"(sum[n][y]) : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " : "=v"(sum[n][y]) : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" : "=v"(sum[n][y]) : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" : "=v"(sum[n][y]) : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" : "=v"(sum[n][y]) : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } + if constexpr (std::is_same_v) { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } + m += CuCount * _WvPrGrp * YTILE; kBase = 0; @@ -1597,4 +1688,4 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, } }); }); -} \ No newline at end of file +} From 87dbab7e7e54b22801d5d1793a506948b0d3b847 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Wed, 23 Apr 2025 08:40:04 +0000 Subject: [PATCH 2/6] fix MI250 Signed-off-by: Hashem Hashemi --- csrc/rocm/skinny_gemms.cu | 85 ++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 45 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 254900643c1d..95e7b234606c 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -275,6 +275,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_sml_(const int K, const int M, const scalar_t* 
From 87dbab7e7e54b22801d5d1793a506948b0d3b847 Mon Sep 17 00:00:00 2001
From: Hashem Hashemi
Date: Wed, 23 Apr 2025 08:40:04 +0000
Subject: [PATCH 2/6] fix the MI250 (non-MFMA) fallback path

Signed-off-by: Hashem Hashemi

---
 csrc/rocm/skinny_gemms.cu | 85 ++++++++++++++++++---------------------
 1 file changed, 40 insertions(+), 45 deletions(-)

diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 254900643c1d..95e7b234606c 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -275,6 +275,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
+
+#if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+#else
+  constexpr bool use_mfma = false;
+#endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
   using half4 =
@@ -348,10 +355,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
     for (int n = 0; n < N; n++)
-      if constexpr (std::is_same_v<scalar_t, half>)
+      if constexpr (!use_mfma)
         sum[n][i] = 0;
       else
-        sum4[n][i] = {0,0,0,0};
+        sum4[n][i] = {0, 0, 0, 0};
 
   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -412,22 +419,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       for (uint32_t n = 0; n < N; n++) {
   #pragma unroll
        for (int y=0; y<YTILE; y++) {
-          if constexpr (std::is_same_v<scalar_t, half>)
+          if constexpr (!use_mfma)
   #pragma unroll
          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
-          }
-          if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
-#if defined(__HIP__MI300__)
+          }
+          else
   #pragma unroll
          for (uint32_t b = 0; b < A_CHUNK / 4; b++)
            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
-#else
-  #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
-          }
-#endif
        }
      }
    }
@@ -436,7 +436,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
    // Final reduction step using shuffle
     //----------------------------------------------------
-    if constexpr (std::is_same_v<scalar_t, half>) {
+    if constexpr (!use_mfma) {
       for (int n = 0; n < N; n++) {
         for (int y = 0; y < YTILE; y++) {
           asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
             : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
         }
       }
-      
+
       if (threadIdx.x == 63) {
         for (int n = 0; n < N; n++) {
           for (int i = 0; i < YTILE; i++) {
@@ -468,8 +468,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
           }
         }
       }
-    }
-    if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+    } else {
   #pragma unroll
       for (int n = 0; n < N; n++) {
   #pragma unroll
@@ -536,6 +535,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
+#if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+#else
+  constexpr bool use_mfma = false;
+#endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
   using half4 =
@@ -634,10 +639,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
     for (int n = 0; n < N; n++)
-      if constexpr (std::is_same_v<scalar_t, half>)
+      if constexpr (!use_mfma)
         sum[n][i] = 0;
       else
-        sum4[n][i] = {0,0,0,0};
+        sum4[n][i] = {0, 0, 0, 0};
 
   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -700,22 +705,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
        // Do the matrix multiplication of activation and weight matrix
        // - Remember the accumulation is happening for K-split of 64!
 #pragma unroll
        for (int y=0; y<YTILE; y++) {
-          if constexpr (std::is_same_v<scalar_t, half>)
+          if constexpr (!use_mfma)
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
          }
-          if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>)
-#if defined(__HIP__MI300__)
+          else
   #pragma unroll
          for (uint32_t b = 0; b < A_CHUNK / 4; b++)
            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
-#else
-  #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
-          }
-#endif
        }
      }
    }
@@ -724,7 +722,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     //----------------------------------------------------
    // Final reduction step using shuffle
     //----------------------------------------------------
-    if constexpr (std::is_same_v<scalar_t, half>) {
+    if constexpr (!use_mfma) {
      for (int n = 0; n < N; n++) {
        for (int y = 0; y < YTILE; y++) {
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -756,8 +754,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
        }
      }
    }
-    }
-    if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+    } else {
   #pragma unroll
      for (int n = 0; n < N; n++) {
   #pragma unroll
@@ -837,6 +834,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
+#if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+#else
+  constexpr bool use_mfma = false;
+#endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
   using half4 =
@@ -955,7 +958,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   for (int i = 0; i < YTILE; i++)
     for (int n = 0; n < N; n++)
-      if constexpr (std::is_same_v<scalar_t, half>)
+      if constexpr (!use_mfma)
         sum[n][i] = 0;
       else
         sum4[n][i] = {0,0,0,0};
@@ -1044,22 +1047,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
        // Do the matrix multiplication of activation and weight matrix
        // - Remember the accumulation is happening for K-split of 64!
#pragma unroll for (int y=0; y) + if constexpr (!use_mfma) #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } - if constexpr (std::is_same_v) -#if defined(__HIP__MI300__) + else #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); -#else - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) - } -#endif } } } @@ -1076,7 +1072,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - if constexpr (std::is_same_v) { + if constexpr (!use_mfma) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " @@ -1108,8 +1104,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } - } - if constexpr (std::is_same_v) { + } else { #pragma unroll for (int n = 0; n < N; n++) { #pragma unroll From 295d4d5dee7f8752106d88812126e8ceb8548d30 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Wed, 23 Apr 2025 17:56:52 +0000 Subject: [PATCH 3/6] lint fix Signed-off-by: Hashem Hashemi --- csrc/rocm/skinny_gemms.cu | 137 +++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 68 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 95e7b234606c..9a89fb695834 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -275,12 +275,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - -#if defined(__HIP__MI300__) + #if defined(__HIP__MI300__) constexpr bool use_mfma = (std::is_same_v); -#else + #else constexpr bool use_mfma = false; -#endif + #endif using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -389,7 +388,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - for (int y=0; y); -#else + #else constexpr bool use_mfma = false; -#endif + #endif using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -672,7 +672,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - for (int b=0; b); -#else + #else constexpr bool use_mfma = false; -#endif + #endif using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -961,7 +962,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (!use_mfma) sum[n][i] = 0; else - sum4[n][i] = {0,0,0,0}; + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -1010,7 +1011,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - for (int b=0; b Date: Fri, 2 May 2025 21:29:02 +0000 Subject: [PATCH 4/6] fix llmm on k size 6114 Signed-off-by: charlifu --- csrc/rocm/skinny_gemms.cu | 28 +++++++++---------- .../quantization/test_rocm_skinny_gemms.py | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 75b0fd61e4a1..b3717892db78 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -126,8 +126,8 @@ 
__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
   const int warp = threadIdx.x / WARP_SIZE;
   const int lane = threadIdx.x % WARP_SIZE;
   const int num_warps = blockDim.x / WARP_SIZE;
-  const int qwarpid = threadid / num_warps;
-  const int qthreadid = threadid % num_warps;
+  const int qwarpid = threadid / 16;
+  const int qthreadid = threadid % 16;
   float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
   scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
   float acc[NUM_A_ROWS_PER_BLOCK];
@@ -142,15 +142,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
       // rowA_elem4[i] holds 8 * half numbers seen as a single float4.
       rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
     }
+    colB_elem4x = bf4[threadid * 4 + 0];
+    colB_elem4y = bf4[threadid * 4 + 1];
+    colB_elem4z = bf4[threadid * 4 + 2];
+    colB_elem4w = bf4[threadid * 4 + 3];
   }
 
-  colB_elem4x = bf4[threadid * 4 + 0];
-  colB_elem4y = bf4[threadid * 4 + 1];
-  colB_elem4z = bf4[threadid * 4 + 2];
-  colB_elem4w = bf4[threadid * 4 + 3];
-
   scalar2_t Af2;
-  [[maybe_unused]] scalar2_t Bf2;
   float2 S;
 
   auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
@@ -193,12 +191,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
 
   if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
     acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
-    for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
+#pragma unroll
+    for (int mask = 16 / 2; mask >= 1; mask /= 2) {
       acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
     }
-    float oval2 = __shfl_xor(acc[qwarpid], num_warps);
+    float oval2 = __shfl_xor(acc[qwarpid], 16);
 
-    if (lane % (num_warps * 2) == 0) {
+    if (lane % 32 == 0) {
       oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
       c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
     }
@@ -222,9 +221,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
   // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
   // operations.
   const int NUM_THREADS =
-      K * 2 / 16 % WARP_SIZE == 0
-          ? K * 2 / 16
-          : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
+      max(rows_per_block * 16,
+          K * 2 / 16 % WARP_SIZE == 0
+              ? K * 2 / 16
+              : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE));
 
   int NUM_BLOCKS = M / rows_per_block;
 
diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
index 622079c39445..76d33169081a 100644
--- a/tests/kernels/quantization/test_rocm_skinny_gemms.py
+++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -8,7 +8,7 @@
 
 DTYPES = [torch.bfloat16, torch.float16]
 M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
-K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]  # k % 8 == 0
+K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192]  # k % 8 == 0
 N = [1, 2, 3, 4]
 SEEDS = [0]
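The bug patch 4 fixes: LLGemm1's cross-warp reduction walked mask values
num_warps/2, ..., 1, which only forms a complete butterfly when the warp
count is a power of two. A K of 6144 launches 6144 * 2 / 16 = 768 threads,
i.e. 12 warps of a 64-wide wavefront, so partial sums were combined
incorrectly (the "6114" in the commit subject appears to be a typo for this
6144 case, which is what the new test entry covers). The fix pins the
reduction tree to a fixed width of 16, reads 0.f for warps that do not
exist, and rounds NUM_THREADS up so at least rows_per_block * 16 threads are
available for the final fold. A hedged sketch of that fold, reusing the
kernel's names; the free-standing wrapper is illustrative only:

    __device__ float fold16(float partial) {
      // 'partial' is red_smem[qwarpid][qthreadid] for live warps, else 0.f.
    #pragma unroll
      for (int mask = 16 / 2; mask >= 1; mask /= 2)
        partial += __shfl_xor(partial, mask);  // 16-lane butterfly reduction
      return partial;  // every lane in the 16-lane group holds the row sum
    }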
From e796ad956a9ecbf44230288d45064c3a3fac718d Mon Sep 17 00:00:00 2001
From: charlifu
Date: Wed, 7 May 2025 14:50:45 +0000
Subject: [PATCH 5/6] enable skinny GEMM for batch size 4

Signed-off-by: charlifu

---
 vllm/model_executor/layers/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index adb966c4b1c0..751b86787c7b 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
-    if m > 8 and 0 < n < 4:
+    if m > 8 and 0 < n <= 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:

From c558b0311de57f41091328dc284c2c294468eb10 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Wed, 7 May 2025 19:56:39 +0000
Subject: [PATCH 6/6] cache the result of on_mi250_mi300

Signed-off-by: charlifu

---
 vllm/platforms/rocm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index ff63f9656c01..b0b037b69da7 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id
 
 
+@cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
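Patches 5 and 6 widen and cheapen the dispatch on the Python side: relaxing
the guard to 0 < n <= 4 lets batch-size-4 skinny GEMMs reach ops.wvSplitK
instead of falling back to the generic linear path, and @cache on
on_mi250_mi300() makes the gcnArchName device-property query run once per
process rather than on every GEMM dispatch.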