diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 6fa2e77299eb0..32f86e20c172f 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -19,7 +19,7 @@ typedef tile<16, 8, half2> tile_C_VKQ_16; // nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. // nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). // Q_in_reg: whether the Q values should be kept permanently in registers. -// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nstages_max: max. number of pipeline stages for cp_async (if available), 1 stage is faster for potential KQ skips, 0 means synchronous data loading. // nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. // nbatch_V2: number of V half2 values in direction of DV to load in parallel. // nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. @@ -32,7 +32,7 @@ struct fattn_mma_f16_config< 64, 64> { static constexpr int nbatch_fa = 64; static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + static constexpr int nstages_max = 2; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 32; @@ -64,7 +64,7 @@ struct fattn_mma_f16_config< 80, 80> { static constexpr int nbatch_fa = 64; static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + static constexpr int nstages_max = 2; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 40; @@ -96,7 +96,7 @@ struct fattn_mma_f16_config< 96, 96> { static constexpr int nbatch_fa = 64; static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + static constexpr int nstages_max = 2; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 48; @@ -128,7 +128,7 @@ struct fattn_mma_f16_config<112, 112> { static constexpr int nbatch_fa = 64; static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + static constexpr int nstages_max = 2; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 56; @@ -160,7 +160,7 @@ struct fattn_mma_f16_config<128, 128> { static constexpr int nbatch_fa = 64; static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + static constexpr int nstages_max = 2; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 64; @@ -192,7 +192,7 @@ struct fattn_mma_f16_config<256, 256> { static constexpr int nbatch_fa = 32; static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + static constexpr int nstages_max = 2; static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { return 128; @@ -210,20 +210,12 @@ struct fattn_mma_f16_config<256, 256> { return 128; } - static int get_nbatch_combine_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 128 : 64; - } - return 64; + static int get_nbatch_combine_host(const int /*cc*/, const int ncols) { + return ncols <= 16 ? 128 : 64; } static constexpr __device__ int get_nbatch_combine_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING return ncols <= 16 ? 128 : 64; -#else - GGML_UNUSED(ncols); - return 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING } }; @@ -232,7 +224,7 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nbatch_fa = 32; static constexpr int nwarps_max = 8; static constexpr bool Q_in_reg = false; - static constexpr int nstages_target = 1; + static constexpr int nstages_max = 1; static int get_nbatch_K2_host(const int cc, const int ncols) { if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { @@ -392,7 +384,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -421,12 +414,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #ifdef NEW_MMA_AVAILABLE typedef fattn_mma_f16_config c; -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. @@ -451,11 +438,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if constexpr (nstages > 1) { static_assert(!mla, "multi-stage loading not implemented for MLA"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); - constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -463,6 +447,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } + if constexpr (nstages < 2) { + if constexpr (nstages > 0) { + cp_async_wait_all(); + } + __syncthreads(); + + static_assert(c::nbatch_fa == WARP_SIZE || c::nbatch_fa == 2*WARP_SIZE, "bad nbatch_fa"); + bool skip; + if constexpr (ncols1 == 1) { + const float2 tmp = __half22float2(tile_mask[c::nbatch_fa == WARP_SIZE ? threadIdx.x % (WARP_SIZE/2) : threadIdx.x]); + skip = isinf(tmp.x) && isinf(tmp.y); + } else { + skip = true; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += WARP_SIZE/(c::nbatch_fa/2)) { + const int j = c::nbatch_fa == WARP_SIZE ? j0 + threadIdx.x / (WARP_SIZE/2) : j0; + const int i = c::nbatch_fa == WARP_SIZE ? threadIdx.x % (WARP_SIZE/2) : threadIdx.x; + + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); + skip = skip && isinf(tmp.x) && isinf(tmp.y); + } + } + + if (__all_sync(0xFFFFFFFF, skip)) { + __syncthreads(); + if constexpr (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + } + } + return; + } + } + + if constexpr (nstages > 1) { + constexpr bool use_cp_async = true; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); + } + #pragma unroll for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; @@ -780,7 +811,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -806,12 +837,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( typedef fattn_mma_f16_config c; -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; @@ -926,13 +951,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -1199,7 +1224,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1243,11 +1268,14 @@ static __global__ void flash_attn_ext_f16( const int ne3) { #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) + typedef fattn_mma_f16_config c; + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { NO_DEVICE_CODE; return; } + #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING if (ncols1*ncols2 > 32) { NO_DEVICE_CODE; @@ -1255,9 +1283,19 @@ static __global__ void flash_attn_ext_f16( } #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); +#ifdef CP_ASYNC_AVAILABLE + if (nstages == 0 || nstages > c::nstages_max) { + NO_DEVICE_CODE; + return; + } +#else + if (nstages != 0) { + NO_DEVICE_CODE; + return; + } +#endif // CP_ASYNC_AVAILABLE - typedef fattn_mma_f16_config c; + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); @@ -1307,12 +1345,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1347,7 +1385,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else @@ -1365,7 +1403,7 @@ static __global__ void flash_attn_ext_f16( #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const int id = ggml_cuda_get_device(); @@ -1373,8 +1411,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml typedef fattn_mma_f16_config c; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; - constexpr int ncols = ncols1 * ncols2; constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. constexpr int cols_per_warp = ntiles * tile_B::I; @@ -1410,7 +1446,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1421,7 +1457,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1437,16 +1473,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } -#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ - template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ +#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2, nstages) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NSTAGES(DKQ, DV, ncols1, ncols2) \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2, 0); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2, 1); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2, 2); \ -#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ - extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ - extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ - extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ - extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ - extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ + DECL_FATTN_MMA_F16_CASE_ALL_NSTAGES(DKQ, DV, (ncols)/ 1, 1); \ + DECL_FATTN_MMA_F16_CASE_ALL_NSTAGES(DKQ, DV, (ncols)/ 2, 2); \ + DECL_FATTN_MMA_F16_CASE_ALL_NSTAGES(DKQ, DV, (ncols)/ 4, 4); \ + DECL_FATTN_MMA_F16_CASE_ALL_NSTAGES(DKQ, DV, (ncols)/ 8, 8); \ + DECL_FATTN_MMA_F16_CASE_ALL_NSTAGES(DKQ, DV, (ncols)/16, 16); \ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) @@ -1477,6 +1518,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) // The number of viable configurations for Deepseek is very limited: -extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); -extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); -extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16, 0); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16, 1); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16, 0); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16, 1); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16, 0); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16, 1); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 6bc0096cc65e6..30e935ad1c0dc 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -8,6 +8,30 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_nstages(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + typedef fattn_mma_f16_config c; + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * mask = dst->src[3]; + + if (!cp_async_available(cc)) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if constexpr (c::nstages_max > 1) { + static_assert(c::nstages_max == 2, "bad nstages_max"); + if (Q->ne[3] == 1 && mask) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; @@ -15,22 +39,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con if constexpr (ncols2 <= 8) { if (Q->ne[1] <= 8/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_nstages(ctx, dst); return; } } if (Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_nstages(ctx, dst); return; } if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_nstages(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_nstages(ctx, dst); } template @@ -325,7 +349,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; - const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && + (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion; const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu index fb26abeb0dab3..af174b1ab5e65 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16, 0); +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index dc16829021f90..d992d0de98029 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8); -DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); -DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); -DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); -DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); -DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu index 9d3cfd8edf74b..d0b233a7aeecb 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1); -DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1); -DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1); -DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1); -DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1); -DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu index 2e1883af40ed2..4b953682c7e4f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2); -DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2); -DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2); -DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2); -DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2); -DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 2074e954a32f0..6583fc4e1e75f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4); -DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4); -DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); -DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); -DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); -DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu index f011a208cd270..c70a3ad3e46e6 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16, 0); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 24c64cf000fec..1cf56d0c6ba8d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4); -DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4); -DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); -DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); -DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); -DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 163b1d939e49d..7f446524a4774 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8); -DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); -DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); -DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); -DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); -DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu index 0543532ea3479..d49b34d4507c7 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1); -DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1); -DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1); -DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1); -DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1); -DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu index 407b6cf4c7020..c750bf993c463 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2); -DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2); -DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2); -DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2); -DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2); -DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu index f5fd0e2369cf2..c79e8576109d3 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16, 0); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu index 5e46685024b84..ed74f4acf57ae 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2); -DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2); -DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2); -DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2); -DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2); -DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 1ada657f194c4..5ff4cbf2da46a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4); -DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4); -DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); -DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); -DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); -DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index bad296b4141e0..301c6b2749a5e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8); -DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); -DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); -DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); -DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); -DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu index 0d7a9c728537d..35937df99de13 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1); -DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1); -DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1); -DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1); -DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1); -DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu index 9d5a9976f0ed1..0b78612780c4c 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1); -DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1); -DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1); -DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1); -DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1); -DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu index a6e6f093dcb24..ab2a81319f8ea 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2); -DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2); -DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2); -DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2); -DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2); -DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 86d4ffae27c28..547b79f1d23ce 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4); -DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4); -DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); -DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); -DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); -DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 680a13ca6de58..ddbd5acaa5b91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -2,9 +2,21 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8); -DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); -DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); -DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); -DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); -DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8, 0); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8, 0); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8, 0); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8, 0); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8, 0); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8, 0); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3428113dc8fd2..4bdd886a78fac 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -18,7 +18,7 @@ """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2}, {nstages});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -71,7 +71,11 @@ def get_head_sizes(type_k, type_v): if head_size_kq == 576 and ncols2 != 16: continue head_size_v = head_size_kq if head_size_kq != 576 else 512 - f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) + for nstages in [0, 1, 2]: + if head_size_kq == 576 and nstages == 2: + continue + f.write(SOURCE_FATTN_MMA_CASE.format( + ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v, nstages=nstages)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: