diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 72d2820f2aab..b3717892db78 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -126,8 +126,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
   const int warp = threadIdx.x / WARP_SIZE;
   const int lane = threadIdx.x % WARP_SIZE;
   const int num_warps = blockDim.x / WARP_SIZE;
-  const int qwarpid = threadid / num_warps;
-  const int qthreadid = threadid % num_warps;
+  const int qwarpid = threadid / 16;
+  const int qthreadid = threadid % 16;
   float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
   scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
   float acc[NUM_A_ROWS_PER_BLOCK];
@@ -142,15 +142,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
       // rowA_elem4[i] holds 8 * half numbers seen as a single float4.
       rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
     }
+    colB_elem4x = bf4[threadid * 4 + 0];
+    colB_elem4y = bf4[threadid * 4 + 1];
+    colB_elem4z = bf4[threadid * 4 + 2];
+    colB_elem4w = bf4[threadid * 4 + 3];
   }
-  colB_elem4x = bf4[threadid * 4 + 0];
-  colB_elem4y = bf4[threadid * 4 + 1];
-  colB_elem4z = bf4[threadid * 4 + 2];
-  colB_elem4w = bf4[threadid * 4 + 3];
-
   scalar2_t Af2;
-  [[maybe_unused]] scalar2_t Bf2;
   float2 S;

   auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
@@ -193,12 +191,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,

   if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
     acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
-    for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
+#pragma unroll
+    for (int mask = 16 / 2; mask >= 1; mask /= 2) {
       acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
     }
-    float oval2 = __shfl_xor(acc[qwarpid], num_warps);
+    float oval2 = __shfl_xor(acc[qwarpid], 16);

-    if (lane % (num_warps * 2) == 0) {
+    if (lane % 32 == 0) {
       oval = __float22s2_rn(make_float2(acc[qwarpid], oval2));
       c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
     }
@@ -222,9 +221,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
   // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
   // operations.
   const int NUM_THREADS =
-      K * 2 / 16 % WARP_SIZE == 0
-          ? K * 2 / 16
-          : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
+      max(rows_per_block * 16,
+          K * 2 / 16 % WARP_SIZE == 0
+              ? K * 2 / 16
+              : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE));

   int NUM_BLOCKS = M / rows_per_block;
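[Review note] The `max(rows_per_block * 16, ...)` guard added to `LLMM1` matters because the reduction now indexes shared memory with `qwarpid = threadid / 16` and `qthreadid = threadid % 16`: without at least 16 threads per A-row, `qwarpid < NUM_A_ROWS_PER_BLOCK` cannot cover every row. A quick sanity sketch of the launcher's thread-count computation in Python (the `WARP_SIZE = 64` wavefront size is an assumption for CDNA GPUs):

```python
WARP_SIZE = 64  # assumption: CDNA wavefront size

def llmm1_num_threads(K: int, rows_per_block: int) -> int:
    # Each thread consumes one float4 (8 halves) of a row, so K * 2 / 16
    # threads cover a row of K halves; round up to a multiple of WARP_SIZE
    # for the warp-shuffle reduction, then keep at least 16 threads per
    # A-row for the new qwarpid / qthreadid indexing.
    per_row = K * 2 // 16
    if per_row % WARP_SIZE != 0:
        per_row += WARP_SIZE - per_row % WARP_SIZE
    return max(rows_per_block * 16, per_row)

# With K=64 and 8 rows per block the old formula launched only 64 threads;
# the new guard raises this to 128 so every row gets its 16 reduction lanes.
assert llmm1_num_threads(K=64, rows_per_block=8) == 128
assert llmm1_num_threads(K=1024, rows_per_block=2) == 128
```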
@@ -275,13 +275,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
+  #if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+  #else
+  constexpr bool use_mfma = false;
+  #endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -318,6 +327,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;

   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];

   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -343,7 +353,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // are being worked on by each wave.
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
-    for (int n = 0; n < N; n++) sum[n][i] = 0;
+    for (int n = 0; n < N; n++)
+      if constexpr (!use_mfma)
+        sum[n][i] = 0;
+      else
+        sum4[n][i] = {0, 0, 0, 0};

   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -374,24 +388,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       if (k_ >= K) break;

       const scalar_t* B_ = &B[(m + 0) * K + k_];
-      bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-      //----------------------------------------------------
-      // The following code with YTILE > 1 has to be deleted
-      //----------------------------------------------------
-      if constexpr (YTILE >= 2)
-        bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-      if constexpr (YTILE >= 3)
-        bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-      if constexpr (YTILE >= 4)
-        bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-      if constexpr (YTILE >= 5)
-        bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-      if constexpr (YTILE >= 6)
-        bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-      if constexpr (YTILE >= 7)
-        bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-      if constexpr (YTILE >= 8)
-        bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+      for (int y = 0; y < YTILE; y++)
+        bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
     }

     // Fetch activation matrix from either just LDS or from both LDS / memory
@@ -419,32 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #pragma unroll
       for (uint32_t n = 0; n < N; n++) {
 #pragma unroll
-        for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-          DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
-          //----------------------------------------------------
-          // The following code with YTILE > 1
-          //----------------------------------------------------
-          if constexpr (YTILE >= 2) {
-            DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-          }
-          if constexpr (YTILE >= 3) {
-            DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-          }
-          if constexpr (YTILE >= 4) {
-            DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-          }
-          if constexpr (YTILE >= 5) {
-            DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-          }
-          if constexpr (YTILE >= 6) {
-            DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-          }
-          if constexpr (YTILE >= 7) {
-            DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-          }
-          if constexpr (YTILE >= 8) {
-            DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-          }
+        for (int y = 0; y < YTILE; y++) {
+          if constexpr (!use_mfma)
+  #pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+            }
+          else
+  #pragma unroll
+            for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+              sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                  bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
         }
       }
     }
@@ -453,37 +436,84 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
   // Final reduction step using shuffle
   //----------------------------------------------------
-  for (int n = 0; n < N; n++) {
-    for (int y = 0; y < YTILE; y++) {
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
"v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } - m += CuCount * _WvPrGrp * YTILE; } } @@ -505,13 +535,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = 
@@ -505,13 +535,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
+  #if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+  #else
+  constexpr bool use_mfma = false;
+  #endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -573,6 +612,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   if (threadIdx.y >= _WvPrGrp) return;

   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];

   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -598,7 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // are being worked on by each wave.
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
-    for (int n = 0; n < N; n++) sum[n][i] = 0;
+    for (int n = 0; n < N; n++)
+      if constexpr (!use_mfma)
+        sum[n][i] = 0;
+      else
+        sum4[n][i] = {0, 0, 0, 0};

   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -628,24 +672,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
      if (k_ >= K) break;

       const scalar_t* B_ = &B[(m + 0) * K + k_];
-      bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-      //----------------------------------------------------
-      // The following code with YTILE > 1 has to be deleted
-      //----------------------------------------------------
-      if constexpr (YTILE >= 2)
-        bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-      if constexpr (YTILE >= 3)
-        bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-      if constexpr (YTILE >= 4)
-        bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-      if constexpr (YTILE >= 5)
-        bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-      if constexpr (YTILE >= 6)
-        bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-      if constexpr (YTILE >= 7)
-        bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-      if constexpr (YTILE >= 8)
-        bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+      for (int b = 0; b < YTILE; b++)
+        bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
     }

     // Fetch activation matrix from either just LDS or from both LDS / memory
@@ -676,32 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Do the matrix multiplication of activation and weight matrix
       // - Remember the accumulation is happening for K-split of 64!
 #pragma unroll
-      for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-        DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-        //----------------------------------------------------
-        // The following code with YTILE > 1
-        //----------------------------------------------------
-        if constexpr (YTILE >= 2) {
-          DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-        }
-        if constexpr (YTILE >= 3) {
-          DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-        }
-        if constexpr (YTILE >= 4) {
-          DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-        }
-        if constexpr (YTILE >= 5) {
-          DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-        }
-        if constexpr (YTILE >= 6) {
-          DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-        }
-        if constexpr (YTILE >= 7) {
-          DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-        }
-        if constexpr (YTILE >= 8) {
-          DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-        }
+      for (int y = 0; y < YTILE; y++) {
+        if constexpr (!use_mfma)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+        else
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
       }
     }
   }
@@ -710,34 +723,82 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
   // Final reduction step using shuffle
   //----------------------------------------------------
-  for (int n = 0; n < N; n++) {
-    for (int y = 0; y < YTILE; y++) {
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+  if constexpr (!use_mfma) {
+    for (int n = 0; n < N; n++) {
+      for (int y = 0; y < YTILE; y++) {
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
+            : "=v"(sum[n][y])
+            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
+            : "=v"(sum[n][y])
+            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
+            : "=v"(sum[n][y])
+            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
+            : "=v"(sum[n][y])
+            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+            : "=v"(sum[n][y])
+            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+            : "=v"(sum[n][y])
+            : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
+      }
     }
-  }
-  if (threadIdx.x == 63) {
+    if (threadIdx.x == 63) {
+      for (int n = 0; n < N; n++) {
+        for (int i = 0; i < YTILE; i++) {
+          if (commitColumn[i])
+            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+        }
+      }
+    }
+  } else {
+  #pragma unroll
     for (int n = 0; n < N; n++) {
-      for (int i = 0; i < YTILE; i++) {
-        if (commitColumn[i])
-          C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+  #pragma unroll
+      for (int y = 0; y < YTILE; y++) {
+        // float accm1 = 0;
+        // for (int i=0; i<64; i++)
+        //   accm1 += __shfl(sum4[n][y][i%4], i);
+
+        float accm = sum4[n][y][0];
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
+            : "=v"(accm)
+            : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
+            : "=v"(accm)
+            : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
+            : "=v"(accm)
+            : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
+            : "=v"(accm)
+            : "0"(accm), "v"(accm), "v"(accm));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
+            : "=v"(accm)
+            : "0"(accm), "v"(accm), "v"(accm));
+        asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
+            : "=v"(accm)
+            : "0"(accm), "v"(accm));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
+            : "=v"(accm)
+            : "0"(accm), "v"(accm), "v"(accm));
+        asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
+            : "=v"(accm)
+            : "0"(accm), "v"(accm), "v"(accm));
+
+        sum4[n][y][0] = accm;
+      }
+    }
+    if (threadIdx.x == 63) {
+      for (int n = 0; n < N; n++) {
+        for (int i = 0; i < YTILE; i++) {
+          // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
+          C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
+        }
       }
     }
   }
@@ -774,14 +835,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
+  #if defined(__HIP__MI300__)
+  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
+  #else
+  constexpr bool use_mfma = false;
+  #endif
+
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
-
+  using half4 =
+      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
   union bigType {
     scalar_t h[A_CHUNK];
     float f[A_CHUNK / 2];
     float2 f2[A_CHUNK / 4];
     double d[A_CHUNK / 4];
+    half4 h4[A_CHUNK / 4];
     scalar8 h8;
   };
@@ -857,6 +926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   kFit = min(kFit, K);

   float sum[N][YTILE];
+  scalar8 sum4[N][YTILE];

   //----------------------------------------------------
   // Each wave works on a single column of weight matrix.
@@ -888,7 +958,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // are being worked on by each wave.
   //----------------------------------------------------
   for (int i = 0; i < YTILE; i++)
-    for (int n = 0; n < N; n++) sum[n][i] = 0;
+    for (int n = 0; n < N; n++)
+      if constexpr (!use_mfma)
+        sum[n][i] = 0;
+      else
+        sum4[n][i] = {0, 0, 0, 0};

   bigType bigA[N][UNRL];
   bigType bigB[YTILE][UNRL];
@@ -937,24 +1011,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       if (k_ >= K) break;

       const scalar_t* B_ = &B[(m + 0) * K + k_];
-      bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
-      //----------------------------------------------------
-      // The following code with YTILE > 1 has to be deleted
-      //----------------------------------------------------
-      if constexpr (YTILE >= 2)
-        bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
-      if constexpr (YTILE >= 3)
-        bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
-      if constexpr (YTILE >= 4)
-        bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
-      if constexpr (YTILE >= 5)
-        bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
-      if constexpr (YTILE >= 6)
-        bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
-      if constexpr (YTILE >= 7)
-        bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
-      if constexpr (YTILE >= 8)
-        bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
+      for (int b = 0; b < YTILE; b++)
+        bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
     }

     // Fetch activation matrix from either just LDS or from both LDS / memory
@@ -989,32 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Do the matrix multiplication of activation and weight matrix
       // - Remember the accumulation is happening for K-split of 64!
 #pragma unroll
-      for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-        DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
-        //----------------------------------------------------
-        // The following code with YTILE > 1
-        //----------------------------------------------------
-        if constexpr (YTILE >= 2) {
-          DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
-        }
-        if constexpr (YTILE >= 3) {
-          DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
-        }
-        if constexpr (YTILE >= 4) {
-          DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
-        }
-        if constexpr (YTILE >= 5) {
-          DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
-        }
-        if constexpr (YTILE >= 6) {
-          DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
-        }
-        if constexpr (YTILE >= 7) {
-          DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
-        }
-        if constexpr (YTILE >= 8) {
-          DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
-        }
+      for (int y = 0; y < YTILE; y++) {
+        if constexpr (!use_mfma)
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+            DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
+          }
+        else
+  #pragma unroll
+          for (uint32_t b = 0; b < A_CHUNK / 4; b++)
+            sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
       }
     }
   }
@@ -1031,34 +1074,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
   // Final reduction step using shuffle
   //----------------------------------------------------
-  for (int n = 0; n < N; n++) {
-    for (int y = 0; y < YTILE; y++) {
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
-      asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
-          : "=v"(sum[n][y])
-          : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
"v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 622079c39445..76d33169081a 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -8,7 +8,7 @@ DTYPES = [torch.bfloat16, torch.float16] M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] -K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 +K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 
diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
index 622079c39445..76d33169081a 100644
--- a/tests/kernels/quantization/test_rocm_skinny_gemms.py
+++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -8,7 +8,7 @@
 DTYPES = [torch.bfloat16, torch.float16]
 M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
-K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]  # k % 8 == 0
+K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192]  # k % 8 == 0
 N = [1, 2, 3, 4]
 SEEDS = [0]
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index adb966c4b1c0..751b86787c7b 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()

-    if m > 8 and 0 < n < 4:
+    if m > 8 and 0 < n <= 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index ff63f9656c01..b0b037b69da7 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id


+@cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
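[Review note] The `@cache` decorator (presumably `functools.cache`; the import sits outside this hunk) memoizes the arch probe, so `torch.cuda.get_device_properties` runs once per process instead of on every call from the `rocm_unquantized_gemm` dispatch path. A minimal standalone sketch of the same pattern:

```python
from functools import cache

import torch

@cache
def on_mi250_mi300() -> bool:
    # The visible GPU does not change at runtime, so the gcnArchName
    # probe can be computed once and reused on every dispatch.
    gcn_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gcn_arch for arch in ("gfx90a", "gfx942"))
```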