diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu index c799dd273dae..ec0f525b39b5 100644 --- a/csrc/rocm/custom.cu +++ b/csrc/rocm/custom.cu @@ -37,14 +37,16 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, } void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, - const int N, cudaStream_t stream, const int CuCount); + const int N, const int Itp_in, cudaStream_t stream, + const int CuCount); void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount) { + const int64_t N_in, const int64_t Itp_in, const int64_t CuCount) { auto M = in_a.size(0); auto K = in_a.size(1); int N = N_in; - wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + int Itp = Itp_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, Itp, at::cuda::getCurrentCUDAStream(), CuCount); } diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 2d4a68fe3e7b..ae065ee8d837 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -329,8 +329,6 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, ///////////////////////////////////////////// -#define DTYPE half - /*__device__ __forceinline__ int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; int rnds0 = N / nPrRnd; @@ -361,7 +359,8 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, }*/ #if defined(__HIP__MI300__) // TODO: Add NAVI support -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, @@ -407,14 +406,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = {0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; + bigType bigB[YTILE][UNRL]; // Fetch the weight matrix from memory! 
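Note on the new argument: Itp_in threads the input element type from the Torch op down to the launcher, with 0 meaning fp16 and 1 meaning bf16 (see the dispatch in wvSpltK_ later in this diff). Below is a minimal host-side sketch of how a caller might derive it from the activation dtype, mirroring the tuned_gemm.py change at the end of the diff; select_itp is a hypothetical helper, not part of the change.

#include <ATen/ATen.h>

// Hypothetical helper (not in the diff): map the activation dtype to the
// Itp convention used by wvSpltK_ (0 = fp16, 1 = bf16).
int64_t select_itp(const at::Tensor& activations) {
  return activations.scalar_type() == at::kHalf ? 0 : 1;
}

// Call shape after this change:
//   wvSpltK(weights, activations, out, N, select_itp(activations), cu_count);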
for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { @@ -423,16 +415,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; - const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); + for (int y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * Kp / 2]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -455,15 +440,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; - float aV[A_CHUNK * 2]; for (uint32_t m = 0; m < M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { - sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); - if (YTILE >= 2) - sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + for (uint32_t y = 0; y < YTILE; y++) + sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, + 0); } } } @@ -538,7 +521,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300__) TODO: Add NAVI support -template +template __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, @@ -550,7 +534,8 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, #endif // defined(__HIP__MI300__) TODO: Add NAVI support #if defined(__HIP__MI300__) // TODO: Add NAVI support -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, @@ -595,14 +580,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = {0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; + bigType bigB[YTILE][UNRL]; // Fetch the weight matrix from memory! 
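The refactor above replaces the eight named bigB0..bigB7 registers and the if (YTILE >= n) ladder with a single bigB[YTILE][UNRL] array walked by a loop whose trip count is a template parameter, so the compiler still fully unrolls it into the same register usage. A plain C++ stand-in for the pattern (scalar floats instead of half8 vectors and loadnt):

// Simplified stand-in, not the HIP kernel: one load per serviced column,
// driven by a compile-time-bounded loop instead of an explicit register list.
template <int YTILE, int UNRL>
void fetch_b_columns(const float* B_, int stride, int k2,
                     float (&bigB)[YTILE][UNRL]) {
  for (int y = 0; y < YTILE; y++)  // replaces the YTILE >= 2 ... YTILE >= 8 checks
    bigB[y][k2] = B_[y * stride];  // non-temporal half8 load in the real kernel
}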
for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { @@ -613,14 +591,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K / 2) break; const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * Kp / 2]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -643,15 +615,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K / 2) break; - float aV[A_CHUNK * 2]; for (uint32_t m = 0; m < M; m++) { for (int i = 0; i < A_CHUNK * 2; i += 8) { - sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); - if (YTILE >= 2) - sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + for (int y = 0; y < YTILE; y++) + sum[m][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[m][y], 0, 0, + 0); } } } @@ -726,7 +696,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300__) TODO: Add NAVI support -template +template __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, @@ -738,18 +709,22 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; half8 h8; }; @@ -760,7 +735,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ DTYPE s[1024 * 32]; //---------------------------------------------------- // Fetch the activation matrix to LDS @@ -793,60 +768,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; float sum[M][YTILE]; + half8 sum4[M][YTILE]; - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. 
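The fp16/bf16 kernels widen the bigType union with a half4 view so each 8-element A_CHUNK can also be read as two 4-wide __bf16 vectors, the operand width expected by __builtin_amdgcn_mfma_f32_4x4x4bf16_1k. A sketch of the same union with the element counts spelled out (A_CHUNK = 8 assumed, matching the <..., 8, ...> instantiations; requires a hipcc/clang toolchain that provides __bf16, as the kernels do):

#include <hip/hip_runtime.h>  // float2

constexpr int A_CHUNK = 8;
using half8 =
    __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;    // 16 B: 8 fp16/bf16 values
using half4 =
    __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;  // 8 B: 4 bf16 values

template <typename DTYPE>
union bigType {
  DTYPE h[A_CHUNK];       // 8 elements of the kernel's element type
  float f[A_CHUNK / 2];   // 4 packed pairs, as consumed by v_dot2c_f32_f16
  float2 f2[A_CHUNK / 4];
  double d[A_CHUNK / 4];
  half4 h4[A_CHUNK / 4];  // 2 x (4 x __bf16) operands for the bf16 MFMA path
  half8 h8;               // single 16-byte vector used for the loadnt fetch
};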
Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0, 0, 0, 0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + bigType bigB[YTILE][UNRL]; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! 
#pragma unroll @@ -855,18 +788,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + const DTYPE* B_ = &B[(n + 0) * K + k_]; + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -877,7 +801,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { // if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); @@ -897,42 +820,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), + "v"(bigB[y][k2].f[b])); + } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -941,33 +842,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - 
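Inside the multiply loop the two data paths are now selected at compile time: the fp16 branch keeps the scalar-accumulating v_dot2c_f32_f16 inline asm, while the bf16 branch feeds 4-wide __bf16 operands to __builtin_amdgcn_mfma_f32_4x4x4bf16_1k and accumulates into a 4-wide float vector. A minimal plain-C++ stand-in for the dispatch shape (tag types and scalar math instead of the GPU intrinsics):

#include <type_traits>

struct Fp16Tag {};  // stands in for half
struct Bf16Tag {};  // stands in for __hip_bfloat16

// Only the branch matching DTYPE is instantiated, so the fp16-only inline asm
// never appears in the bf16 build of the kernel and vice versa.
template <typename DTYPE>
void accumulate(float& sum, float (&sum4)[4], float a, float b) {
  if constexpr (std::is_same_v<DTYPE, Fp16Tag>) {
    sum += a * b;                // models v_dot2c_f32_f16 into a scalar float
  } else {
    for (int i = 0; i < 4; ++i)  // models the 4-wide MFMA accumulator
      sum4[i] += a * b;
  }
}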
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + if constexpr (std::is_same_v) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } } } - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { + #pragma unroll for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2half(sum[m][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[m][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } } } } @@ -986,7 
+932,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -996,18 +943,22 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; half8 h8; }; @@ -1018,7 +969,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ DTYPE s[1024 * 32]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! @@ -1075,60 +1026,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.y >= _WvPrGrp) return; float sum[M][YTILE]; + half8 sum4[M][YTILE]; - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0, 0, 0, 0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! 
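One property worth noting about the __shared__ half s[1024 * 32] to __shared__ DTYPE s[1024 * 32] change: the LDS staging buffer keeps its 64 KB footprint for both supported element types, since half and __hip_bfloat16 are both 2 bytes. A compile-time check expressing that assumption (not part of the diff):

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>

static_assert(sizeof(half) == 2 && sizeof(__hip_bfloat16) == 2,
              "the 1024 * 32 element LDS buffer assumes 2-byte element types");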
- // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- + bigType bigB[YTILE][UNRL]; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! #pragma unroll @@ -1137,18 +1046,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + const DTYPE* B_ = &B[(n + 0) * K + k_]; + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((half8*)(&B_[y * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1159,7 +1059,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); @@ -1170,51 +1069,29 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! 
+ for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), + "v"(bigB[y][k2].f[b])); + } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -1223,33 +1100,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + if constexpr (std::is_same_v) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), 
"v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } } } - - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { + #pragma unroll for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[m][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } } } } @@ -1269,7 +1191,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -1279,51 +1202,36 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets big A[] cases, where it is much larger than LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; half8 h8; }; - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! 
- //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ DTYPE s[1024 * 32]; - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- uint32_t commitColumn[YTILE]; for (uint32_t i = 0; i < YTILE; i++) { commitColumn[i] = 1; } - // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - // Check whether there will be fragmenation! - // This will happen only for the last wave! if (n < N && (n + YTILE) >= N) { uint32_t startColumn = N - YTILE; for (uint32_t i = 0; i < (n - startColumn); i++) { @@ -1332,30 +1240,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) n = startColumn; } - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- #define PCML #ifndef PCML for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - if (k_in >= min(K * M, 32 * 1024)) break; - - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); #endif @@ -1373,22 +1264,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) kFit = min(kFit, K); float sum[M][YTILE]; + half8 sum4[M][YTILE]; - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- #ifdef PCML int YW = (YTILE * _WvPrGrp); uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); @@ -1396,45 +1273,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else while (n < N) { #endif - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. 
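Under PCML the big-A kernel rounds N up to a multiple of YTILE * _WvPrGrp so every wave runs the same number of outer-loop iterations and reaches the staged LDS reloads together. The Nrndp formula restated as a small helper (assumption: standalone function; the kernel computes it inline):

#include <cstdint>

uint32_t round_n_for_pcml(uint32_t N, uint32_t YTILE, uint32_t WvPrGrp) {
  const uint32_t YW = YTILE * WvPrGrp;
  return (N % YW == 0) ? N : (N - N % YW + YW);
}
// Example: N = 100, YTILE = 2, _WvPrGrp = 16 -> YW = 32, Nrndp = 128.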
- //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int m = 0; m < M; m++) + if constexpr (std::is_same_v) + sum[m][i] = 0; + else + sum4[m][i] = {0, 0, 0, 0}; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - bigType bigB9[UNRL]; - bigType bigB10[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- + bigType bigB[YTILE][UNRL]; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { #ifdef PCML if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS @@ -1462,18 +1309,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + const DTYPE* B_ = &B[(n + 0) * K + k_]; + for (uint32_t y = 0; y < YTILE; y++) + bigB[y][k2].h8 = loadnt((half8*)(&B_[y * K])); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1503,47 +1341,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + for (uint32_t y = 0; y < YTILE; y++) { + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(bigA[m][k2].f[b]), + "v"(bigB[y][k2].f[b])); + } + if constexpr (std::is_same_v) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[m][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[m][k2].h4[b], bigB[y][k2].h4[b], sum4[m][y], 0, 0, 0); } } } @@ -1560,33 +1376,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + if constexpr (std::is_same_v) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + 
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } } } - - if (threadIdx.x == 63) { + if constexpr (std::is_same_v) { + #pragma unroll for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[m][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[m][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[m][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2bfloat16(sum4[m][i][0]); + } } } } @@ -1606,7 +1467,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template +template __global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { @@ -1644,52 +1506,77 @@ int mindiv(int N, int div1, int div2) { } void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { + const int K_in, const int N_in, const int Itp_in, + cudaStream_t stream, const int CuCount = 0) { dim3 grid(CuCount); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); - -#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ - _N) \ +#define WVSPLTK(_DTYPE, _WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, \ + _UNRLb, _N) \ { \ dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + wvSpltK_hf_sml_<_DTYPE, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + wvSpltK_hf_<_DTYPE, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ <<>>(K_in, M_in, af4, 
bf4, c, __wvPrGrp, \ CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + wvSpltK_hf_big_<_DTYPE, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } \ } - switch (N_in) { - case 1: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 - break; - case 2: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 - break; - case 3: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 - break; - case 4: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + - "," + std::to_string(K_in) + "," + - std::to_string(N_in)); + if (Itp_in == 0) { // fp16 + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + + switch (N_in) { + case 1: + WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLTK(half, 16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLTK(half, 16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + } else if (Itp_in == 1) { // bf16 + __hip_bfloat16* af4 = reinterpret_cast<__hip_bfloat16*>(in_a); + const __hip_bfloat16* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast<__hip_bfloat16*>(out_c); + + switch (N_in) { + case 1: + WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLTK(__hip_bfloat16, 16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLTK(__hip_bfloat16, 16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } } cudaError_t err = cudaGetLastError(); @@ -1715,12 +1602,12 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + wvSpltKQ_hf_sml_ \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + wvSpltKQ_hf_ \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } \ @@ -1728,16 +1615,16 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, switch (N_in) { case 1: - WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 1) break; case 2: - WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 2) break; case 3: - WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 3) break; case 4: - WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 4) break; default: throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 0dc5490cfbe1..584a9cdaabb2 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -9,7 +9,7 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t rows_per_block); void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount); + const int64_t 
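The launcher now branches on Itp_in before any pointers are typed: 0 reinterprets the buffers as half, 1 as __hip_bfloat16, and the per-N kernel selection is otherwise identical through the _DTYPE-parameterized WVSPLTK macro. A condensed sketch of the dispatch (launch_for<T> is a hypothetical stand-in for the macro body):

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>

template <typename T>
void launch_for(void* in_a, void* in_b, void* out_c) {
  T* af4 = reinterpret_cast<T*>(in_a);
  const T* bf4 = reinterpret_cast<const T*>(in_b);
  T* c = reinterpret_cast<T*>(out_c);
  // ... switch on N_in and launch wvSpltK_hf_sml_ / wvSpltK_hf_ / wvSpltK_hf_big_
  //     instantiated as <T, 64, YTILE, WvPrGrp, 8, UNRL, N>, as in WVSPLTK.
  (void)af4; (void)bf4; (void)c;
}

void dispatch_by_itp(int Itp_in, void* in_a, void* in_b, void* out_c) {
  if (Itp_in == 0)
    launch_for<half>(in_a, in_b, out_c);            // fp16
  else if (Itp_in == 1)
    launch_for<__hip_bfloat16>(in_a, in_b, out_c);  // bf16
  // note: the diff has no branch for other Itp values; error handling there
  // is limited to the cudaGetLastError() check after the launch.
}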
N_in, const int64_t Itp_in, const int64_t CuCount); void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in, diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index fab6a6942054..22f01d72ba8f 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -42,7 +42,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); rocm_ops.def( "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," - " int CuCount) -> ()"); + " int Itp_in, int CuCount) -> ()"); rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); rocm_ops.def( "wvSpltKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1099f0c5c72c..7fdf08010e00 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1417,8 +1417,8 @@ def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, - cu_count: int) -> None: - torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) + Itp: int, cu_count: int) -> None: + torch.ops._rocm_C.wvSpltK(a, b, out, N, Itp, cu_count) def wvSpltKQ(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 2f26bf5c365b..77110690b8f8 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -76,16 +76,21 @@ def query_sol(self, m, n, k, bias, dtype): def apply_skinny(self, m, n, k, inp_view, weights): if not self.use_skinny: return None - if inp_view.dtype != torch.float16 or k % 8 != 0: + if ((inp_view.dtype != torch.float16) and + (inp_view.dtype != torch.bfloat16)) or k % 8 != 0: return None - if m > 8 and 0 < n <= 4: + if m >= 8 and 0 < n <= 4: out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, device='cuda') - ops.wvSpltK(weights, inp_view, out, n, self.cu_count) + Itp = 1 #default bfloat16 + if inp_view.dtype == torch.float16: + Itp = 0 + ops.wvSpltK(weights, inp_view, out, n, Itp, self.cu_count) return out - elif m % 4 == 0 and n == 1 and k <= 8192: + elif m % 4 == 0 and n == 1 and k <= 8192 and (inp_view.dtype + == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, @@ -105,7 +110,7 @@ def scaled_mm( bias: Optional[torch.Tensor], ) -> torch.Tensor: n = inp.shape[0] - if (not VLLM_USE_ROCM_SKINNY_GEMM or n != 1 + if (not VLLM_USE_ROCM_SKINNY_GEMM or n > 4 or not current_platform.is_rocm() or is_mi250() or is_navi()): return torch._scaled_mm(inp, weight,