From 1d907c8402a9d219879491e0ceabb1f28b3ebcef Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 7 Apr 2025 17:34:49 +0000 Subject: [PATCH 01/13] Adds bf16 support to wvSpltk solution for sknny GEMMs. --- csrc/rocm/custom.cu | 7 +- csrc/rocm/custom_kernels.cu | 715 +++++++++-------------- csrc/rocm/torch_bindings.cpp | 2 +- vllm/_custom_ops.py | 4 +- vllm/model_executor/layers/tuned_gemm.py | 7 +- 5 files changed, 300 insertions(+), 435 deletions(-) diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu index c799dd273dae..1d39d5dd8910 100644 --- a/csrc/rocm/custom.cu +++ b/csrc/rocm/custom.cu @@ -37,14 +37,15 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, } void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, - const int N, cudaStream_t stream, const int CuCount); + const int N, const int Itp_in, cudaStream_t stream, const int CuCount); void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount) { + const int64_t N_in, const int64_t Itp_in, const int64_t CuCount) { auto M = in_a.size(0); auto K = in_a.size(1); int N = N_in; - wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + int Itp = Itp_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, Itp, at::cuda::getCurrentCUDAStream(), CuCount); } diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 2d4a68fe3e7b..f346d2a00c38 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -329,7 +329,6 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, ///////////////////////////////////////////// -#define DTYPE half /*__device__ __forceinline__ int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; @@ -361,7 +360,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, }*/ #if defined(__HIP__MI300__) // TODO: Add NAVI support -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, @@ -407,14 +406,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = {0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; + bigType bigB[YTILE][UNRL]; // Fetch the weight matrix from memory! 
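      // Note: each lane issues one non-temporal 16-byte (half8) load per tile
      // column below; bigB[y][k2] keeps all YTILE columns fetched in this
      // unroll step. B holds fp8-packed data addressed through half pointers,
      // which is why the indexing uses Kp / 2 and K / 2.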
for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { @@ -423,16 +415,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; - const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); + for (int y=0; y< YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * Kp / 2]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -455,15 +440,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; - float aV[A_CHUNK * 2]; for (uint32_t m = 0; m < M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { - sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); - if (YTILE >= 2) - sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + for (uint32_t y = 0; y < YTILE; y++) + sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][0], 0, 0, 0); } } } @@ -538,7 +520,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300__) TODO: Add NAVI support -template +template __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, @@ -550,7 +532,7 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, #endif // defined(__HIP__MI300__) TODO: Add NAVI support #if defined(__HIP__MI300__) // TODO: Add NAVI support -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, @@ -595,14 +577,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = {0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; + bigType bigB[YTILE][UNRL]; // Fetch the weight matrix from memory! 
for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { @@ -613,14 +588,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K / 2) break; const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * Kp / 2]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -643,15 +612,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; - float aV[A_CHUNK * 2]; for (uint32_t m = 0; m < M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { - sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); - if (YTILE >= 2) - sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + for (int y = 0; y < YTILE; y++) + sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][0], 0, 0, 0); } } } @@ -726,7 +692,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300__) TODO: Add NAVI support -template +template __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, @@ -738,18 +704,21 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; half8 h8; }; @@ -760,7 +729,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ DTYPE s[1024 * 32]; //---------------------------------------------------- // Fetch the activation matrix to LDS @@ -793,60 +762,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; float sum[M][YTILE]; + half8 sum4[M][YTILE]; - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. 
Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- for (int i = 0; i < YTILE; i++) for (int m = 0; m < M; m++) sum[m][i] = 0; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + bigType bigB[YTILE][UNRL]; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! 
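      // Note: B_ is typed on the DTYPE template parameter, so the same
      // 16-byte (half8) non-temporal load path serves both the half and the
      // __hip_bfloat16 instantiations; bigB[y][k2] gathers the YTILE columns
      // fetched in this unroll step.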
#pragma unroll @@ -855,18 +778,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + const DTYPE* B_ = &B[(n + 0) * K + k_]; + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -877,7 +791,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { // if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); @@ -897,43 +810,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), "v"(bigB[y][k2].f[b])); } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" + : "=v"(sum4[m][y]) + : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); + + } } } } @@ -941,7 +833,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { + if constexpr (std::is_same_v) { + for (int 
m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " : "=v"(sum[m][y]) @@ -962,14 +855,53 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) : "=v"(sum[m][y]) : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); } - } - if (threadIdx.x == 63) { + } + if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); C[n + i + m * N] = __float2half(sum[m][i]); } } + } + } + if constexpr (std::is_same_v) { + #pragma unroll + for (int m = 0; m < M; m++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + //for (int i=0; i<64; i++) + // accm += __shfl(sum[m][y][i%4], i); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + accm += __shfl_down(accm, 32); + accm += __shfl_down(accm, 16); + + sum4[m][y][0] = accm; + } + } + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } + } + } } n += CuCount * _WvPrGrp * YTILE; @@ -986,7 +918,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -996,18 +928,21 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; half8 h8; }; @@ -1018,7 +953,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ DTYPE s[1024 * 32]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! @@ -1075,60 +1010,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.y >= _WvPrGrp) return; float sum[M][YTILE]; + half8 sum4[M][YTILE]; - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. 
- // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- + bigType bigB[YTILE][UNRL]; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! 
#pragma unroll @@ -1137,18 +1030,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + const DTYPE* B_ = &B[(n + 0) * K + k_]; + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1159,7 +1043,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); @@ -1170,52 +1053,31 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! 
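      // Note: the accumulation is dispatched on DTYPE. The half path packs
      // element pairs into floats and accumulates with v_dot2c_f32_f16 into
      // sum[m][y]; the bf16 path feeds half4 (4 x __bf16) fragments to the
      // 4x4x4 bf16 MFMA and accumulates into sum4[m][y] (with A_CHUNK = 8 as
      // instantiated here, four floats per lane), which are reduced across
      // the wave after the K loop.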
+ for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), "v"(bigB[y][k2].f[b])); } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" + : "=v"(sum4[m][y]) + : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); + + } } } } @@ -1223,7 +1085,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { + if constexpr (std::is_same_v) { + for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " : "=v"(sum[m][y]) @@ -1244,14 +1107,53 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) : "=v"(sum[m][y]) : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } } - - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { + #pragma unroll + for (int m = 0; m < M; m++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + //for (int i=0; i<64; i++) + // accm += __shfl(sum[m][y][i%4], i); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + accm += __shfl_down(accm, 32); + accm += 
__shfl_down(accm, 16); + + sum4[m][y][0] = accm; + } + } + if (threadIdx.x == 0) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); } } + } } n += CuCount * _WvPrGrp * YTILE; @@ -1269,7 +1171,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -1279,51 +1181,35 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets big A[] cases, where it is much larger than LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; half8 h8; }; - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ DTYPE s[1024 * 32]; - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- uint32_t commitColumn[YTILE]; for (uint32_t i = 0; i < YTILE; i++) { commitColumn[i] = 1; } - // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
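  // Note: last-wave fragmentation handling. If this tile would run past
  // column N, the wave is shifted back to start at N - YTILE and
  // commitColumn[] records which of its columns were already covered by the
  // previous tile.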
if (n < N && (n + YTILE) >= N) { uint32_t startColumn = N - YTILE; for (uint32_t i = 0; i < (n - startColumn); i++) { @@ -1332,30 +1218,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) n = startColumn; } - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- #define PCML #ifndef PCML for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - if (k_in >= min(K * M, 32 * 1024)) break; - - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); #endif @@ -1373,22 +1242,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) kFit = min(kFit, K); float sum[M][YTILE]; + half8 sum4[M][YTILE]; - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- #ifdef PCML int YW = (YTILE * _WvPrGrp); uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); @@ -1396,45 +1251,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else while (n < N) { #endif - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - bigType bigB9[UNRL]; - bigType bigB10[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! 
- // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- + bigType bigB[YTILE][UNRL]; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { #ifdef PCML if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS @@ -1462,18 +1287,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + const DTYPE* B_ = &B[(n + 0) * K + k_]; + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = loadnt((half8*)(&B_[y * K])); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1503,48 +1319,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), "v"(bigB[y][k2].f[b])); } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" + : "=v"(sum4[m][y]) + : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); + + } } } } @@ -1560,7 +1355,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { + if constexpr (std::is_same_v) { + for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " : "=v"(sum[m][y]) @@ -1581,14 +1377,53 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) : "=v"(sum[m][y]) : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } } - - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { + #pragma unroll + for (int m = 0; m < M; m++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + //for (int i=0; i<64; i++) + // accm += __shfl(sum[m][y][i%4], i); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + accm += __shfl_down(accm, 32); + accm += __shfl_down(accm, 16); + + 
sum4[m][y][0] = accm; + } + } + if (threadIdx.x == 0) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); } } + } } n += CuCount * _WvPrGrp * YTILE; @@ -1606,7 +1441,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -1644,52 +1479,78 @@ int mindiv(int N, int div1, int div2) { } void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { + const int K_in, const int N_in, const int Itp_in, + cudaStream_t stream, const int CuCount = 0) { dim3 grid(CuCount); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); - -#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ +#define WVSPLTK(_DTYPE, _WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ _N) \ { \ dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + wvSpltK_hf_sml_<_DTYPE, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + wvSpltK_hf_<_DTYPE, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + wvSpltK_hf_big_<_DTYPE, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } \ } - switch (N_in) { + if (Itp_in == 0) { // fp16 + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + + switch (N_in) { + case 1: + WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + } + else if (Itp_in == 1) {// bf16 + __hip_bfloat16* af4 = reinterpret_cast<__hip_bfloat16*>(in_a); + const __hip_bfloat16* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast<__hip_bfloat16*>(out_c); + + switch (N_in) { case 1: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 1) break; case 2: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 2) break; case 3: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 3) break; case 4: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 4) break; default: throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + "," + std::to_string(K_in) + "," + 
std::to_string(N_in)); + } } cudaError_t err = cudaGetLastError(); @@ -1715,12 +1576,12 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + wvSpltKQ_hf_sml_ \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + wvSpltKQ_hf_ \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } \ @@ -1728,16 +1589,16 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, switch (N_in) { case 1: - WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 1) break; case 2: - WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 2) break; case 3: - WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 3) break; case 4: - WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 4) break; default: throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index fab6a6942054..22f01d72ba8f 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -42,7 +42,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); rocm_ops.def( "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," - " int CuCount) -> ()"); + " int Itp_in, int CuCount) -> ()"); rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); rocm_ops.def( "wvSpltKQ(Tensor in_a, Tensor in_b, Tensor! 
out_c, Tensor scale_a, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1099f0c5c72c..d240cbc5cfa7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1417,8 +1417,8 @@ def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, - cu_count: int) -> None: - torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) + Itp: int, cu_count: int) -> None: + torch.ops._rocm_C.wvSpltK(a, b, out, N, Itp, cu_count) def wvSpltKQ(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 2f26bf5c365b..c58927439373 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -76,14 +76,17 @@ def query_sol(self, m, n, k, bias, dtype): def apply_skinny(self, m, n, k, inp_view, weights): if not self.use_skinny: return None - if inp_view.dtype != torch.float16 or k % 8 != 0: + if (inp_view.dtype != torch.float16) and (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: return None if m > 8 and 0 < n <= 4: out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, device='cuda') - ops.wvSpltK(weights, inp_view, out, n, self.cu_count) + Itp = 1 #default bfloat16 + if out_dtype == torch.float16: + Itp = 0 + ops.wvSpltK(weights, inp_view, out, n, Itp, self.cu_count) return out elif m % 4 == 0 and n == 1 and k <= 8192: out = torch.empty(inp_view.shape[0], From 96bddb76b2526dbbc9f2ab7003b64f96a2402bd1 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 7 Apr 2025 21:06:41 +0000 Subject: [PATCH 02/13] typo fix --- vllm/model_executor/layers/tuned_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index c58927439373..7e76f4bbab8c 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -76,7 +76,7 @@ def query_sol(self, m, n, k, bias, dtype): def apply_skinny(self, m, n, k, inp_view, weights): if not self.use_skinny: return None - if (inp_view.dtype != torch.float16) and (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: + if ((inp_view.dtype != torch.float16) and (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: return None if m > 8 and 0 < n <= 4: out = torch.empty(inp_view.shape[0], From cd0ac7c15849aa1dd0d2dcf787f74dfa9d04aa80 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 7 Apr 2025 21:16:19 +0000 Subject: [PATCH 03/13] typo fixes (2) --- csrc/rocm/custom_kernels.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index f346d2a00c38..d271c9af3415 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -444,8 +444,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t m = 0; m < M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { for (uint32_t y = 0; y < YTILE; y++) - sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][0], 0, 0, 0); + sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, 0); } } } @@ -617,7 +617,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int i = 0; i < A_CHUNK * 2; i += 8) { for (int y = 0; y < YTILE; y++) sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][0], 0, 0, 0); 
+ bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, 0); } } } From 5ec47d7ba0284156760cfe6695eb241fd95587c3 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 7 Apr 2025 22:29:32 +0000 Subject: [PATCH 04/13] lint fixes --- csrc/rocm/custom.cu | 3 +- csrc/rocm/custom_kernels.cu | 637 ++++++++++++++++++------------------ 2 files changed, 327 insertions(+), 313 deletions(-) diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu index 1d39d5dd8910..ec0f525b39b5 100644 --- a/csrc/rocm/custom.cu +++ b/csrc/rocm/custom.cu @@ -37,7 +37,8 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, } void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, - const int N, const int Itp_in, cudaStream_t stream, const int CuCount); + const int N, const int Itp_in, cudaStream_t stream, + const int CuCount); void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t N_in, const int64_t Itp_in, const int64_t CuCount) { diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index d271c9af3415..b059ceabf8b8 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -329,7 +329,6 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, ///////////////////////////////////////////// - /*__device__ __forceinline__ int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; int rnds0 = N / nPrRnd; @@ -360,7 +359,8 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, }*/ #if defined(__HIP__MI300__) // TODO: Add NAVI support -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, @@ -416,7 +416,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; - for (int y=0; y< YTILE; y++) + for (int y = 0; y < YTILE; y++) bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * Kp / 2]))); } @@ -443,9 +443,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t m = 0; m < M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { - for (uint32_t y = 0; y < YTILE; y++) + for (uint32_t y = 0; y < YTILE; y++) sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, 0); + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, + 0); } } } @@ -520,7 +521,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300__) TODO: Add NAVI support -template +template __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, @@ -532,7 +534,8 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, #endif // defined(__HIP__MI300__) TODO: Add NAVI support #if defined(__HIP__MI300__) // TODO: Add NAVI support -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, @@ -588,7 +591,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K / 2) break; const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; - for (uint32_t y = 0; y < YTILE; y++) + for (uint32_t y = 0; y < YTILE; y++) bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * Kp / 2]))); } @@ -615,9 +618,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t m = 0; m < 
M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { - for (int y = 0; y < YTILE; y++) + for (int y = 0; y < YTILE; y++) sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, 0); + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, + 0); } } } @@ -692,7 +696,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300__) TODO: Add NAVI support -template +template __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, @@ -704,7 +709,8 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, @@ -779,7 +785,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const DTYPE* B_ = &B[(n + 0) * K + k_]; - for (uint32_t y = 0; y < YTILE; y++) + for (uint32_t y = 0; y < YTILE; y++) bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * K]))); } @@ -805,27 +811,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t y = 0; y < YTILE; y++) { - if constexpr (std::is_same_v) - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), "v"(bigB[y][k2].f[b])); + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), + "v"(bigB[y][k2].f[b])); + } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" + : "=v"(sum4[m][y]) + : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), + "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); } - if constexpr (std::is_same_v) - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 4; b++) - asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" - : "=v"(sum4[m][y]) - : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); - - } } } } @@ -833,75 +840,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - if constexpr (std::is_same_v) { - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - 
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2half(sum[m][i]); + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } } } - } } if constexpr (std::is_same_v) { #pragma unroll - for (int m = 0; m < M; m++) { + for (int m = 0; m < M; m++) { #pragma unroll - for (int y = 0; y < YTILE; y++) { - float accm = sum4[m][y][0]; - //for (int i=0; i<64; i++) - // accm += __shfl(sum[m][y][i%4], i); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - accm += __shfl_down(accm, 32); - accm += __shfl_down(accm, 16); - - sum4[m][y][0] = accm; + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + // for (int i=0; i<64; i++) + // accm += __shfl(sum[m][y][i%4], i); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + accm += __shfl_down(accm, 32); + accm += __shfl_down(accm, 16); + + 
sum4[m][y][0] = accm; + } } - } - if (threadIdx.x == 0) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } } } - } } n += CuCount * _WvPrGrp * YTILE; @@ -918,7 +925,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -928,7 +936,8 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, @@ -1015,10 +1024,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) while (n < N) { for (int i = 0; i < YTILE; i++) for (int m = 0; m < M; m++) - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) sum[m][i] = 0; - else - sum4[m][i] = {0}; + else + sum4[m][i] = {0}; bigType bigA[M][UNRL]; bigType bigB[YTILE][UNRL]; @@ -1031,7 +1040,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const DTYPE* B_ = &B[(n + 0) * K + k_]; - for (uint32_t y = 0; y < YTILE; y++) + for (uint32_t y = 0; y < YTILE; y++) bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * K]))); } @@ -1057,27 +1066,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
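          // The multiply loop below resolves the element type at compile time:
          // the half instantiation folds pairs of halves into a scalar float
          // per (m, y) tile with v_dot2c_f32_f16, while the __hip_bfloat16
          // instantiation feeds 4-element bf16 chunks to the 4x4x4 bf16 MFMA
          // and keeps a float4 accumulator (sum4). Only one branch is
          // instantiated per DTYPE, so the existing fp16 path is effectively
          // unchanged by the bf16 addition.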
#pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t y = 0; y < YTILE; y++) { - if constexpr (std::is_same_v) - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), "v"(bigB[y][k2].f[b])); + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), + "v"(bigB[y][k2].f[b])); + } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" + : "=v"(sum4[m][y]) + : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), + "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); } - if constexpr (std::is_same_v) - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 4; b++) - asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" - : "=v"(sum4[m][y]) - : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); - - } } } } @@ -1085,75 +1095,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - if constexpr (std::is_same_v) { - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2half(sum[m][i]); + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + 
m * N] = __float2half(sum[m][i]); + } } } - } } if constexpr (std::is_same_v) { #pragma unroll - for (int m = 0; m < M; m++) { + for (int m = 0; m < M; m++) { #pragma unroll - for (int y = 0; y < YTILE; y++) { - float accm = sum4[m][y][0]; - //for (int i=0; i<64; i++) - // accm += __shfl(sum[m][y][i%4], i); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - accm += __shfl_down(accm, 32); - accm += __shfl_down(accm, 16); - - sum4[m][y][0] = accm; + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + // for (int i=0; i<64; i++) + // accm += __shfl(sum[m][y][i%4], i); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + accm += __shfl_down(accm, 32); + accm += __shfl_down(accm, 16); + + sum4[m][y][0] = accm; + } } - } - if (threadIdx.x == 0) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } } } - } } n += CuCount * _WvPrGrp * YTILE; @@ -1171,7 +1181,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -1181,7 +1192,8 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets big A[] cases, where it is much larger than LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, @@ -1252,11 +1264,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) while (n < N) { #endif for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) - if constexpr (std::is_same_v) - sum[m][i] = 0; - else - sum4[m][i] = {0}; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0}; bigType bigA[M][UNRL]; bigType bigB[YTILE][UNRL]; @@ -1280,7 +1292,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (n >= N) 
continue; #endif - // Fetch the weight matrix from memory! + // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -1319,27 +1331,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t y = 0; y < YTILE; y++) { - if constexpr (std::is_same_v) - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), "v"(bigB[y][k2].f[b])); + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), + "v"(bigB[y][k2].f[b])); + } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" + : "=v"(sum4[m][y]) + : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), + "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); } - if constexpr (std::is_same_v) - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 4; b++) - asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" - : "=v"(sum4[m][y]) - : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); - - } } } } @@ -1355,75 +1368,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - if constexpr (std::is_same_v) { - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2half(sum[m][i]); + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : 
"0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } } } - } } if constexpr (std::is_same_v) { #pragma unroll - for (int m = 0; m < M; m++) { + for (int m = 0; m < M; m++) { #pragma unroll - for (int y = 0; y < YTILE; y++) { - float accm = sum4[m][y][0]; - //for (int i=0; i<64; i++) - // accm += __shfl(sum[m][y][i%4], i); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - accm += __shfl_down(accm, 32); - accm += __shfl_down(accm, 16); - - sum4[m][y][0] = accm; + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + // for (int i=0; i<64; i++) + // accm += __shfl(sum[m][y][i%4], i); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + accm += __shfl_down(accm, 32); + accm += __shfl_down(accm, 16); + + sum4[m][y][0] = accm; + } } - } - if (threadIdx.x == 0) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } } } - } } n += CuCount * _WvPrGrp * YTILE; @@ -1441,7 +1454,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -1479,78 +1493,77 @@ int mindiv(int N, int div1, int div2) { } void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, const int Itp_in, - cudaStream_t stream, const int CuCount = 0) { + const int 
K_in, const int N_in, const int Itp_in, + cudaStream_t stream, const int CuCount = 0) { dim3 grid(CuCount); -#define WVSPLTK(_DTYPE, _WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ - _N) \ +#define WVSPLTK(_DTYPE, _WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, \ + _UNRLb, _N) \ { \ dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltK_hf_sml_<_DTYPE, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + wvSpltK_hf_sml_<_DTYPE, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltK_hf_<_DTYPE, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + wvSpltK_hf_<_DTYPE, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSpltK_hf_big_<_DTYPE, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + wvSpltK_hf_big_<_DTYPE, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } \ } - if (Itp_in == 0) { // fp16 - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); - - switch (N_in) { - case 1: - WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 1) - break; - case 2: - WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 2) - break; - case 3: - WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 3) - break; - case 4: - WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 4) - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + - "," + std::to_string(K_in) + "," + - std::to_string(N_in)); - } - } - else if (Itp_in == 1) {// bf16 - __hip_bfloat16* af4 = reinterpret_cast<__hip_bfloat16*>(in_a); - const __hip_bfloat16* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast<__hip_bfloat16*>(out_c); - - switch (N_in) { - case 1: - WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 1) - break; - case 2: - WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 2) - break; - case 3: - WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 3) - break; - case 4: - WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 4) - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + - "," + std::to_string(K_in) + "," + - std::to_string(N_in)); - } + if (Itp_in == 0) { // fp16 + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + + switch (N_in) { + case 1: + WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + } else if (Itp_in == 1) { // bf16 + __hip_bfloat16* af4 = reinterpret_cast<__hip_bfloat16*>(in_a); + const __hip_bfloat16* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast<__hip_bfloat16*>(out_c); + + switch (N_in) { + case 1: + WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + 
std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } } cudaError_t err = cudaGetLastError(); @@ -1576,12 +1589,12 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltKQ_hf_sml_ \ + wvSpltKQ_hf_sml_ \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltKQ_hf_ \ + wvSpltKQ_hf_ \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } \ From ad43bfe683ea57efe592f88c6997cbb54c21ff67 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 7 Apr 2025 22:43:31 +0000 Subject: [PATCH 05/13] more lint fixes --- csrc/rocm/custom_kernels.cu | 14 +++++++------- vllm/_custom_ops.py | 2 +- vllm/model_executor/layers/tuned_gemm.py | 5 +++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index b059ceabf8b8..97303a3046a8 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -811,8 +811,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll @@ -1066,8 +1066,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll @@ -1292,7 +1292,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (n >= N) continue; #endif - // Fetch the weight matrix from memory! + // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -1331,8 +1331,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
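          // "K-split of 64" refers to the wavefront: K is walked in strides of
          // THRDS * A_CHUNK * UNRL, with lane t of the 64-wide wave covering
          // the A_CHUNK contiguous elements starting at k + t * A_CHUNK. Each
          // lane therefore holds only a partial dot product at this point; the
          // partials are combined afterwards in the "Final reduction step
          // using shuffle" below.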
#pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d240cbc5cfa7..7fdf08010e00 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1417,7 +1417,7 @@ def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, - Itp: int, cu_count: int) -> None: + Itp: int, cu_count: int) -> None: torch.ops._rocm_C.wvSpltK(a, b, out, N, Itp, cu_count) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 7e76f4bbab8c..3ec3378ac974 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -76,7 +76,8 @@ def query_sol(self, m, n, k, bias, dtype): def apply_skinny(self, m, n, k, inp_view, weights): if not self.use_skinny: return None - if ((inp_view.dtype != torch.float16) and (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: + if ((inp_view.dtype != torch.float16) and + (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: return None if m > 8 and 0 < n <= 4: out = torch.empty(inp_view.shape[0], @@ -85,7 +86,7 @@ def apply_skinny(self, m, n, k, inp_view, weights): device='cuda') Itp = 1 #default bfloat16 if out_dtype == torch.float16: - Itp = 0 + Itp = 0 ops.wvSpltK(weights, inp_view, out, n, Itp, self.cu_count) return out elif m % 4 == 0 and n == 1 and k <= 8192: From 66d98cee69610669265da4a945d292336e92e318 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 7 Apr 2025 23:03:50 +0000 Subject: [PATCH 06/13] bug fix --- vllm/model_executor/layers/tuned_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 3ec3378ac974..f8441ef54b46 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -85,7 +85,7 @@ def apply_skinny(self, m, n, k, inp_view, weights): dtype=inp_view.dtype, device='cuda') Itp = 1 #default bfloat16 - if out_dtype == torch.float16: + if inp_view.dtype == torch.float16: Itp = 0 ops.wvSpltK(weights, inp_view, out, n, Itp, self.cu_count) return out From b586e74e38394f628bf9e7f565b8d64f29026b11 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Apr 2025 21:07:38 +0000 Subject: [PATCH 07/13] bug fix (2) --- csrc/rocm/ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 0dc5490cfbe1..584a9cdaabb2 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -9,7 +9,7 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t rows_per_block); void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount); + const int64_t N_in, const int64_t Itp_in, const int64_t CuCount); void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in, From d65ea534fb5567bb6a60939e1d31f82c3d0298af Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 00:23:40 +0000 Subject: [PATCH 08/13] bug fix (3) --- csrc/rocm/custom_kernels.cu | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 97303a3046a8..3f3979b0c84e 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -772,7 +772,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) while (n < N) { for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) 
sum[m][i] = 0; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0}; bigType bigA[M][UNRL]; bigType bigB[YTILE][UNRL]; From 5cbee0b2dcac6d544d573f1fc7465147182113c7 Mon Sep 17 00:00:00 2001 From: Parker McLeod Date: Wed, 9 Apr 2025 01:22:18 +0000 Subject: [PATCH 09/13] bug fix (4) --- vllm/model_executor/layers/tuned_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index f8441ef54b46..2868fddc866f 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -79,7 +79,7 @@ def apply_skinny(self, m, n, k, inp_view, weights): if ((inp_view.dtype != torch.float16) and (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: return None - if m > 8 and 0 < n <= 4: + if m >= 8 and 0 < n <= 4: out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, @@ -89,7 +89,7 @@ def apply_skinny(self, m, n, k, inp_view, weights): Itp = 0 ops.wvSpltK(weights, inp_view, out, n, Itp, self.cu_count) return out - elif m % 4 == 0 and n == 1 and k <= 8192: + elif m % 4 == 0 and n == 1 and k <= 8192 and (inp_view.dtype == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, From 9f60e634cc003dda611d364d3b66ef6a37509b32 Mon Sep 17 00:00:00 2001 From: Parker McLeod Date: Wed, 9 Apr 2025 01:28:41 +0000 Subject: [PATCH 10/13] expand wvspltKQ to be used upto BS4. --- vllm/model_executor/layers/tuned_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 2868fddc866f..74fd6e547996 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -109,7 +109,7 @@ def scaled_mm( bias: Optional[torch.Tensor], ) -> torch.Tensor: n = inp.shape[0] - if (not VLLM_USE_ROCM_SKINNY_GEMM or n != 1 + if (not VLLM_USE_ROCM_SKINNY_GEMM or n > 4 or not current_platform.is_rocm() or is_mi250() or is_navi()): return torch._scaled_mm(inp, weight, From 8527696d00aec7fdb86d8c21fc25b129a0a65342 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Tue, 22 Apr 2025 21:49:52 +0000 Subject: [PATCH 11/13] lint fix --- vllm/model_executor/layers/tuned_gemm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 74fd6e547996..77110690b8f8 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -89,7 +89,8 @@ def apply_skinny(self, m, n, k, inp_view, weights): Itp = 0 ops.wvSpltK(weights, inp_view, out, n, Itp, self.cu_count) return out - elif m % 4 == 0 and n == 1 and k <= 8192 and (inp_view.dtype == torch.float16): + elif m % 4 == 0 and n == 1 and k <= 8192 and (inp_view.dtype + == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, From 5de7ef86f49f1f2261d8eead9c4ad743473eb233 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Wed, 23 Apr 2025 05:34:38 +0000 Subject: [PATCH 12/13] mfma builtin fix --- csrc/rocm/custom_kernels.cu | 66 +++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 3f3979b0c84e..0274a2c8736a 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -776,7 +776,7 @@ __global__ void __launch_bounds__(WvPrGrp* 
THRDS) if constexpr (std::is_same_v) sum[m][i] = 0; else - sum4[m][i] = {0}; + sum4[m][i] = {0,0,0,0}; bigType bigA[M][UNRL]; bigType bigB[YTILE][UNRL]; @@ -832,10 +832,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) - asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" - : "=v"(sum4[m][y]) - : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), - "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -881,9 +878,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) { #pragma unroll for (int y = 0; y < YTILE; y++) { + //float accm1 = 0; + //for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[m][y][i%4], i); + float accm = sum4[m][y][0]; - // for (int i=0; i<64; i++) - // accm += __shfl(sum[m][y][i%4], i); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm) : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); @@ -899,13 +898,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm) : "0"(accm), "v"(accm), "v"(accm)); - accm += __shfl_down(accm, 32); - accm += __shfl_down(accm, 16); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); sum4[m][y][0] = accm; } } - if (threadIdx.x == 0) { + if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); @@ -1087,10 +1093,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) - asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" - : "=v"(sum4[m][y]) - : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), - "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -1137,8 +1140,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (int y = 0; y < YTILE; y++) { float accm = sum4[m][y][0]; - // for (int i=0; i<64; i++) - // accm += __shfl(sum[m][y][i%4], i); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm) : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); @@ -1154,13 +1155,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm) : "0"(accm), "v"(accm), "v"(accm)); - accm += __shfl_down(accm, 32); - accm += __shfl_down(accm, 16); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); sum4[m][y][0] = accm; } } - if (threadIdx.x == 0) { + if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); @@ -1352,10 +1360,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) #pragma unroll 
for (uint32_t b = 0; b < A_CHUNK / 4; b++) - asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4" - : "=v"(sum4[m][y]) - : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]), - "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y])); + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -1410,8 +1415,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (int y = 0; y < YTILE; y++) { float accm = sum4[m][y][0]; - // for (int i=0; i<64; i++) - // accm += __shfl(sum[m][y][i%4], i); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm) : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); @@ -1427,13 +1430,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm) : "0"(accm), "v"(accm), "v"(accm)); - accm += __shfl_down(accm, 32); - accm += __shfl_down(accm, 16); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); sum4[m][y][0] = accm; } } - if (threadIdx.x == 0) { + if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); From df245515975fc87f968927c624551fa609dcf51a Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Wed, 23 Apr 2025 05:48:10 +0000 Subject: [PATCH 13/13] lint fix --- csrc/rocm/custom_kernels.cu | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 0274a2c8736a..ae065ee8d837 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -776,7 +776,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) sum[m][i] = 0; else - sum4[m][i] = {0,0,0,0}; + sum4[m][i] = {0, 0, 0, 0}; bigType bigA[M][UNRL]; bigType bigB[YTILE][UNRL]; @@ -832,7 +832,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) - sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -878,10 +879,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) { #pragma unroll for (int y = 0; y < YTILE; y++) { - //float accm1 = 0; - //for (int i=0; i<64; i++) - // accm1 += __shfl(sum4[m][y][i%4], i); - float accm = sum4[m][y][0]; asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm) @@ -1037,7 +1034,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) sum[m][i] = 0; else - sum4[m][i] = {0}; + sum4[m][i] = {0, 0, 0, 0}; bigType bigA[M][UNRL]; bigType bigB[YTILE][UNRL]; @@ -1093,7 +1090,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (std::is_same_v) #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) - sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -1280,7 +1278,7 @@ __global__ void 
__launch_bounds__(WvPrGrp* THRDS)
         if constexpr (std::is_same_v<DTYPE, half>)
           sum[m][i] = 0;
         else
-          sum4[m][i] = {0};
+          sum4[m][i] = {0, 0, 0, 0};
     bigType bigA[M][UNRL];
     bigType bigB[YTILE][UNRL];
@@ -1360,7 +1358,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
           if constexpr (std::is_same_v<DTYPE, __hip_bfloat16>)
   #pragma unroll
             for (uint32_t b = 0; b < A_CHUNK / 4; b++)
-              asm("V_MFMA_F32_4X4X4_16B_BF16 %0, %2, %3, %4"
-                  : "=v"(sum4[m][y])
-                  : "0"(sum4[m][y]), "v"(bigA[m][k2].h4[b]),
-                    "v"(bigB[y][k2].h4[b]), "v"(sum4[m][y]));
+              sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                  bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0);
         }
       }
     }
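// Editorial sketch, not part of the patch: a plain-C++ reference for the
// product the wvSpltK kernels compute in the skinny case. Element types are
// simplified to float here; in the patch Itp_in == 0 selects half and
// Itp_in == 1 selects __hip_bfloat16, and the kernels tile this loop across
// wavefronts and CUs. The function name and the use of std::vector are
// illustrative assumptions only.
#include <cstddef>
#include <vector>

// out[b * M + r] = sum_k act[b * K + k] * wgt[r * K + k]
//   wgt : M x K row-major weights            (in_a in the patch)
//   act : N x K row-major activations, N <= 4 (in_b in the patch)
//   out : N x M row-major result             (out_c in the patch)
static void wvSpltK_reference(const std::vector<float>& wgt,
                              const std::vector<float>& act,
                              std::vector<float>& out, std::size_t M,
                              std::size_t K, std::size_t N) {
  for (std::size_t b = 0; b < N; ++b)
    for (std::size_t r = 0; r < M; ++r) {
      float acc = 0.0f;  // the kernels also accumulate in fp32 (sum / sum4)
      for (std::size_t k = 0; k < K; ++k)
        acc += act[b * K + k] * wgt[r * K + k];
      out[b * M + r] = acc;
    }
}

// The bf16 and fp16 kernel paths differ mainly in how the per-lane partial
// products are formed (v_dot2c_f32_f16 into a scalar float vs. the 4x4x4 bf16
// MFMA into a float4), in the cross-lane reduction, and in the final store
// (__float2half vs. __float2bfloat16).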