diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 8471541e144..700dc08f8c7 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -155,10 +155,10 @@ def get_rocm_tuning_space(use_fp16): # small search space, no pruning required # bypassLDS: block_n/num_warps=16 for perf block_m_range = [16, 32, 64, 128, 256] - block_n_range = [128] if use_fp16 else [64] + block_n_range = [128] if use_fp16 else [128] block_k_range = [128] if use_fp16 else [256] - num_warps_range = [8] if use_fp16 else [4] + num_warps_range = [8] if use_fp16 else [8] group_m_range = [1] # For now we see better perf with num_stages=0 for all gemm configs we care # But keep this explicit so that we do not forget we may need to set it to @@ -211,6 +211,8 @@ def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]: keys, values = zip(*param_ranges.items()) for config_values in product(*values): config = dict(zip(keys, config_values)) + assert config['num_warps'] == config['BLOCK_SIZE_N'] // 16, \ + "num_warps should be equal to BLOCK_SIZE_N divided by 16" configs.append(config) return configs diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 483584dd804..813de437b61 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -10,7 +10,7 @@ create_kv_caches_with_random) NUM_BLOCKS = 1024 * 1024 -PARTITION_SIZE = 512 +PARTITION_SIZE = 256 @torch.inference_mode() @@ -101,7 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = 0.1 for _ in range(num_iters): if version == "v1": @@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE ) else: raise ValueError(f"Invalid version: {version}") @@ -174,13 +176,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: # Warmup. print("Warming up...") run_benchmark = run_cuda_benchmark - run_benchmark(num_iters=3, profile=False) + run_benchmark(num_iters=500, profile=False) # Benchmark. if do_profile: latency = run_benchmark(num_iters=1, profile=True) else: - latency = run_benchmark(num_iters=1000, profile=False) + latency = run_benchmark(num_iters=10000, profile=False) print(f"Kernel running time: {latency * 1000000:.3f} us") diff --git a/benchmarks/kernels/tune_script.sh b/benchmarks/kernels/tune_script.sh index eea752b5ce2..f9b396bfd0e 100755 --- a/benchmarks/kernels/tune_script.sh +++ b/benchmarks/kernels/tune_script.sh @@ -4,6 +4,7 @@ export FUSED_MOE_PERSISTENT=1 export VLLM_MOE_PADDING=128 export VLLM_MOE_SHUFFLE=1 export TRITON_HIP_USE_NEW_STREAM_PIPELINE=1 +export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ## ---- Mixtral fp8 tuning ---- ## diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index e4f6615ede1..c05ec89f03c 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,16 +1,185 @@ -#include "common.cuh" -#include "dispatch_utils.h" - +#include +#include #include -#ifndef USE_ROCM +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#if defined(USE_CUDA_FP8_FORMAT) + #include #include #else + #include #include #endif +#if defined(USE_CUDA_FP8_FORMAT) +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = + std::numeric_limits::max(); +#else + #include "amd/hip_float8.h" +using FP8_TYPE = c10::Float8_e4m3fnuz; +// Using the default max value from pytorch (240.0) will cause accuracy +// issue when running dynamic quantization. Here use 224.0f for rocm. +constexpr auto FP8_E4M3_MAX = 224.0f; +#endif + namespace vllm { +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); + + return old; +} + +template +__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, + float const scale) { + float x = 0.0f; + if constexpr (is_scale_inverted) { + x = val * scale; + } else { + x = val / scale; + } + + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); +#if defined(USE_CUDA_FP8_FORMAT) + return static_cast(r); +#else + // Use hardware cvt instruction for fp8 on rocm + return c10::Float8_e4m3fnuz(hip_fp8(r).data, + c10::Float8_e4m3fnuz::from_bits()); +#endif +} + +// Compute the absolute maximum m of the input tensor and store +// m / float8_e4m3::max() in *scale. Each thread block performs a +// reduction tree and the memory in scale is atomically updated. +// So to get the right answer, *scale needs to be initialized to +// a value <= 0.0 and we need to wait for all thread blocks to +// finish before consuming *scale. +template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { + __shared__ float cache[1024]; + int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + + // First store maximum for all values processes by + // the current thread in cache[threadIdx.x] + scalar_t tmp = 0.0; + while (i < num_elems) { + float x = static_cast(input[i]); + tmp = max(tmp, fabs(x)); + i += blockDim.x * gridDim.x; + } + cache[threadIdx.x] = tmp; + + __syncthreads(); + + // Now perform parallel reduction within the thread block + int ib = blockDim.x / 2; + while (ib != 0) { + if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { + cache[threadIdx.x] = cache[threadIdx.x + ib]; + } + __syncthreads(); + ib /= 2; + } + // Finally, since cache[0] contains the maximum for this thread block, + // atomically write the max to the target location + if (threadIdx.x == 0) { + atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); + } +} + +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +typedef struct __align__(4) { + FP8_TYPE x; + FP8_TYPE y; + FP8_TYPE z; + FP8_TYPE w; +} +float8x4_t; + +template +__device__ float thread_max_vec(scalar_t const* __restrict__ input, + int64_t const num_elems, int const tid, + int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + + int64_t const num_vec_elems = num_elems >> 2; + float absmax_val = 0.0f; + +#pragma unroll 4 + for (int64_t i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + absmax_val = max(absmax_val, fabs(in_vec.x)); + absmax_val = max(absmax_val, fabs(in_vec.y)); + absmax_val = max(absmax_val, fabs(in_vec.z)); + absmax_val = max(absmax_val, fabs(in_vec.w)); + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + absmax_val = max(absmax_val, fabs(input[i])); + } + + return absmax_val; +} + +template +__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, + scalar_t const* __restrict__ input, + float const scale, + int64_t const num_elems, + int const tid, int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + float8x4_t* vectorized_out = reinterpret_cast(out); + + int64_t const num_vec_elems = num_elems >> 2; + +#pragma unroll 4 + for (int64_t i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + float8x4_t out_vec; + + out_vec.x = scaled_fp8_conversion( + static_cast(in_vec.x), scale); + out_vec.y = scaled_fp8_conversion( + static_cast(in_vec.y), scale); + out_vec.z = scaled_fp8_conversion( + static_cast(in_vec.z), scale); + out_vec.w = scaled_fp8_conversion( + static_cast(in_vec.w), scale); + vectorized_out[i] = out_vec; + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + out[i] = scaled_fp8_conversion( + static_cast(input[i]), scale); + } +} + template __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, const scalar_t* __restrict__ input, diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index ab8edd6d0f5..0d0c16d77b6 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -29,6 +29,11 @@ #define __HIP__MI300_MI250__ #endif +#if defined(__HIPCC__) && (defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -42,6 +47,7 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 @@ -51,6 +57,9 @@ using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(_Float16)))) _Float16; +typedef float16x2 _Half2; typedef struct _Half8 { _Half4 xy[2]; } _Half8; @@ -63,8 +72,13 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; +using _B8x4 = int32_t; //used in builtins using bit8_t = uint8_t; +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + ////// Non temporal load stores /////// template @@ -77,6 +91,23 @@ __device__ __forceinline__ void store(T value, T* addr) { addr[0] = value; } + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + auto res = make_float4(dat0,dat1,dat2,dat3); + return *reinterpret_cast<_B16x8*>(&res); +} + + template __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, const _B16x4& inpB, @@ -92,6 +123,21 @@ __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -140,17 +186,25 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { } t16; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; + union h2cvt { + __half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3])); + return u.b16x4; } else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < 4; i++) { - t16.b = __float2bfloat16(inp[i]); - ret[i] = t16.u; + union fcvt { + uint32_t u32; + float f32; + } u; + u.f32 = inp[i]; + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //RNE with no nan/inf check + ret[i] = uint16_t(u.u32 >> 16); + //t16.b = __float2bfloat16(inp[i]); + //ret[i] = t16.u; } return ret; } else { @@ -168,21 +222,26 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } t1, t2, res; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; + union h2cvt { + _B16x4 b16x4; + __half2 h2[2]; + } u1,u2,s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; + s.h2[0] = u1.h2[0] + u2.h2[0]; + s.h2[1] = u1.h2[1] + u2.h2[1]; + return s.b16x4; } else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.b = t1.b + t2.b; - ret[i] = res.u; + union fcvt { + float f32; + uint32_t i32; + } u1,u2,s; + u1.i32 = uint32_t(inp1[i])<<16; + u2.i32 = uint32_t(inp2[i])<<16; + s.f32 = u1.f32 + u2.f32; + ret[i] = uint16_t(s.i32>>16); } return ret; } else { @@ -210,8 +269,547 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8_custom(const _B8x8 input, + const float scale) { + union { + floatx4 f32x4[2]; + vllm::Float8_ f32x8; + } tmpf8; + tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input)); + + tmpf8.f32x4[0] *= scale; + tmpf8.f32x4[1] *= scale; + + _B16x8 ret; + ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); + ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); + return ret; +} + +__device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0],inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2],inp[3]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = inp[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + int32_t tmpf8; + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false); + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true); + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true); + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(f0[0],f0[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(f1[0],f1[1]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + int32_t tmpf8; + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false); + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true); + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true); + floatx4 tmpf; + tmpf[0] = f0[0]; + tmpf[1] = f0[1]; + tmpf[2] = f1[0]; + tmpf[3] = f1[1]; + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = tmpf[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + + + +template +__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { + union { + _B8x8 b8x8; + _B8x4 b8x4[2]; + } tmp; + tmp.b8x8 = input; + _B16x8 ret; + for (int i=0; i<2; i++) { + ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) ); + } + return ret; +} /////////////////////////////////////// +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; //partition size set to 256 TODO move to template param + //const int partition_size = 256; //blockDim.x; //TODO this could be head_size or partition_size + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; //partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO,4); + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + //shared_logits is used for multiple purposes + //__shared__ _B16x4 shared_logits[NWARPS][4][16][4 + 1]; + __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; + + //for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = WARP_SIZE / 16; //rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); //8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; //each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); //1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; //4xQKHE_16B across warp + + _B16x8 Qlocal[QKHELOOP][QK_SIZE_RATIO]; //note that 16 contiguous elements of Q should be fetched per lane for 8 bit cache types : QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + //constexpr int x = CONTIGUOUS_SCALAR_ELEMS_16B; //x is defined by vLLM as 16Bytes + + //constexpr int TLOOP1 = CONTIGUOUS_KV_ELEMS_16B_LOAD / 4; //mfma16x16x16 outputs 4 elements per lane: will be moved to match layout for V dwordx4 loads + //constexpr int TOKENS_PER_WARP1 = 16 * TLOOP1; //16 tokens across lanes * TLOOP factor + //constexpr int T_PAR_LOOP = T_PAR_SIZE / TOKENS_PER_WARP1 / NWARPS; + constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //sub partition of tokens per warp for qk calculation + constexpr int TLOOP = TOKENS_PER_WARP / 16; //each mfma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP][QKHELOOP]; //this could be B8x16 too + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + //for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + //each mfma takes QH16xT16x16HE across warp + //repeat mfmas across QKHELOOP dimension + //output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens across 4 rowsx4 tokens per lane + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + //fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + //fetch Q in shared + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; //+ rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + if (local_qhead_idx < GQA_RATIO) { + const scalar_t* q_fetch_ptr = q_ptr + lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; //this works for head size 128 : 16 lanes x 8 elems = 128 elems + const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i=0; i<2; i++) { + const int head_elem = lane16id * 2 + i; //element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem /4/4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][2*qkratio + i]; + } + } + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP;// 16 * T_PAR_SIZE / 256; + constexpr int VBLOCKS_PER_LANE = DIVIDE_ROUND_UP(VTOKENS_PER_LANE,BLOCK_SIZE); + constexpr int VTLOOP = NWARPS; //was * TOKENS_PER_WARP / ROWS_PER_WARP / VTOKENS_PER_LANE; + constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 + constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + //fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) { + const int vlocal_token_idx = vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this could be B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + + //v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD / BLOCK_SIZE; + //const int token_depth = vtoken_depth * VBLOCKS_PER_LANE + vblock_depth; + const int64_t vblock_number = static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + + //__syncthreads(); //if using shared Q + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + scale2 *= k_scale; + } + + floatx4 dout[TLOOP]; + + //Q stored in registers + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } else { //kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocaltmp.xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } + } + dout[token_depth] *= scale2; + } + + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + + const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i=0; i<4; i++) { + const float tmp = (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max,mask)); + } + + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i=0; i<4; i++) { + const float tmp = (local_token_idx + i < context_len) ? __expf(dout[token_depth][i] - qk_max) : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { + exp_sum += __shfl_xor(exp_sum,mask); + } + + __syncthreads(); //sync before writing to shared mem + + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + //shared_qk_max[warpid][lane16id] = qk_max; + //shared_exp_sum[warpid][lane16id] = exp_sum; + const int qk_max_offset = warpid*16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS*16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } + + __syncthreads(); + + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + for (int w=0; w(dout[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); + } + } + + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + (wg_start_head_idx + qhead_idx) * max_num_partitions + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x4 outelems[VHELOOP]; + _B16x4 S_local[VTLOOP][2][2]; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + //for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int j=0; j<2; j++) { + for (int i=0; i<2; i++) { + const int offset = 4*rowid + 2*j + i; + const int offset1 = offset % 4; + const int offset2 = offset / 4; + S_local[vtoken_depth][j][i] = shared_logits[vtoken_depth][offset2][lane16id][offset1]; + } + } + //} + } + } + //v layout: 16he across lanes x 16 tokens per lane + + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i=0; i<2; i++) { + //TODO generalize this for 8 bit dtypes: each lane needs 2*vfetch_depth + 2 _B16x4 K/token dimension elems; each row is multiplied by a factor of 4 + //layout: lane in depth dimension | row across -> + //0 4 8 12 + //1 5 9 13 + //2 6 10 14 + //3 7 11 15 + const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i; + const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP + const int offset2 = offset / 4; + + //if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j=0; j<2; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i=0; i<2; i++) { + const int offset = 4*rowid + 2*j + i; + const int offset1 = offset % 4; + const int offset2 = offset / 4; + tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], + S_local[vtoken_depth][j][i], + tmp_out); + } + } + } + + } + } + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= v_scale; + } + outelems[vhe_depth] = from_floatx4(tmp_out); + } + + __syncthreads(); + + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; //lane16 id head dimension; rowid head element dimension + } + + __syncthreads(); + + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int head_elem_idx = lane16id * 8; + const int offset1 = (head_elem_idx / 16)%4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4)%4; + for (int i=0; i<2; i++) { + vout[h].xy[i] = shared_logits[offset1][offset2][local_head_idx][offset3+i]; + } + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + + seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + const int head_elem_idx = lane16id * 8; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + +} +///////////////////////////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) template 16/2 // 8xtokens + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; @@ -312,8 +913,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; if constexpr (GQA_RATIO < 12) { @@ -370,6 +969,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } +#if 1 float alibi_slope[QHLOOP]; if (alibi_slopes != nullptr) { #pragma unroll @@ -380,6 +980,26 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( : 0.f; } } +#endif +#if 0 + float alibi_slope; + const int lane16id = laneid % 16; + if (alibi_slopes != nullptr) { + alibi_slope = (lane16id < GQA_RATIO) + ? alibi_slopes[wg_start_head_idx + lane16id] + : 0.f; + //#pragma unroll + // for (int h = 0; h < QHLOOP; h++) { + // for (int i=0; i<4; i++) { + // const int qhead_idx = h * 4 + i; + // alibi_slope[qhead_idx] = (qhead_idx < GQA_RATIO) + // ? alibi_slopes[wg_start_head_idx + qhead_idx] + // : 0.f; + // } + //} + //} + } +#endif // fetch vphysical block numbers up front if constexpr (GQA_RATIO >= 12) { @@ -392,6 +1012,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } +#if 1 //fetch vcache in normal case const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); @@ -416,7 +1037,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - } else { + } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) +#endif +#if 1 //fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block #pragma unroll @@ -435,23 +1059,73 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // iterate over all velems within block #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { - // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - const _B8x8 Vlocalb8 = v_ptrh8be[d]; - Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + //const _B8x8 Vlocalb8 = v_ptrh8be[d]; + //Vlocal[h][b * BLOCK_SIZE / 8 + d] = + // scaled_convert_b8x8(Vlocalb8, v_scale); } } } } - +#endif +#if 0 //cvt kf8 to kf/bf16 up front if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); + //scaled_convert_b8x8(Klocalb8[d], k_scale); + convert_b8x8_custom(Klocalb8[d]); } } +#endif + + /*Klocal[x] = scaled_convert_b8x8(Klocalb8[x], k_scale); \*/ + /*Klocal[x] = scaled_convert_b8x8_custom(Klocalb8[x], k_scale); \*/ +#define QK_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], \ + Klocal[x].xy[0], dout[h]);\ + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], \ + Klocal[x].xy[1], dout[h]);\ + } + //#pragma unroll + //for (int h = 0; h < QHLOOP; h++) { + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); + } + //} +#undef QK_mfma + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + scale2 *= k_scale; + } + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] *= scale2; + //if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // dout[h] *= k_scale; + //} + } +#if 0 #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], @@ -522,10 +1196,39 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } // KHELOOP>8 dout[h] *= scale; } +#endif + +#if 0 + if (alibi_slopes != nullptr) { + float alibi_slope_local[GQA_RATIO]; +#define DPP_BCAST_ASM(id) asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:id " : "=v"(alibi_slope_local[id]) : "v"(alibi_slope)); + //for (int head=0; head < 16; head++) { + //DPP_BCAST_ASM(0); + if constexpr(GQA_RATIO>0) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:0 " : "=v"(alibi_slope_local[0]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>1) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:1 " : "=v"(alibi_slope_local[1]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>2) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:2 " : "=v"(alibi_slope_local[2]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>3) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:3 " : "=v"(alibi_slope_local[3]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>4) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:4 " : "=v"(alibi_slope_local[4]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>5) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:5 " : "=v"(alibi_slope_local[5]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>6) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:6 " : "=v"(alibi_slope_local[6]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>7) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:7 " : "=v"(alibi_slope_local[7]) : "v"(alibi_slope));} + //} + + const int alibi_offset = global_token_idx - context_len + 1; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope_local[4*h+i] * alibi_offset; + } + } + } +#endif // transpose dout so that 4 token ids are in each lane, and 4 heads are across // 4 lanes #pragma unroll for (int h = 0; h < QHLOOP; h++) { +#if 1 floatx4 tmp = {0}; #pragma unroll for (int i = 0; i < 4; i++) { @@ -535,9 +1238,38 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; +#endif +#if 0 + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); + + bool mask = (lane4id % 2) == 1; + float tmp = dout[h][1]; + dout[h][1] = mask ? dout[h][0] : dout[h][1]; + dout[h][0] = mask ? tmp : dout[h][0]; + tmp = dout[h][3]; + dout[h][3] = mask ? dout[h][2] : dout[h][3]; + dout[h][2] = mask ? tmp : dout[h][2]; + + mask = (lane4id>>1) == 1; + tmp = dout[h][2]; + dout[h][2] = mask ? dout[h][0] : dout[h][2]; + dout[h][0] = mask ? tmp : dout[h][0]; + tmp = dout[h][3]; + dout[h][3] = mask ? dout[h][1] : dout[h][3]; + dout[h][1] = mask ? tmp : dout[h][1]; + + + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); + +#endif } const int lane4_token_idx = 4 * (global_token_idx >> 2); +#if 1 //alibi after transpose const int alibi_offset = lane4_token_idx - context_len + 1; if (alibi_slopes != nullptr) { #pragma unroll @@ -548,6 +1280,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } +#endif + + const int bpermute_mask = 4*(16*((laneid>>2)%4) + lane4id); #pragma unroll for (int h = 0; h < QHLOOP; h++) { @@ -559,11 +1294,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( : qk_max[h]; } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + for (int mask = WARP_SIZE / 2; mask >= 64; mask /= 2) { qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); } + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + + //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(qk_max[h]) : "v"(bpermute_mask), "v"(qk_max[h]) ); + + //qk_max[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, qk_max[h]); + auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&qk_max[h])); + qk_max[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); } + float exp_sum[QHLOOP]; #pragma unroll for (int h = 0; h < QHLOOP; h++) { @@ -576,17 +1322,28 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( exp_sum[h] += dout[h][i]; } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + for (int mask = WARP_SIZE / 2; mask >= 64; mask /= 2) { exp_sum[h] += __shfl_xor(exp_sum[h], mask); } + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + + //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(exp_sum[h]) : "v"(bpermute_mask), "v"(exp_sum[h]) ); + //exp_sum[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, exp_sum[h]); + auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&exp_sum[h])); + exp_sum[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); } + if (laneid<4) { #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int head_idx = 4 * h + lane4id; shared_qk_max[warpid][head_idx] = qk_max[h]; shared_exp_sum[warpid][head_idx] = exp_sum[h]; } + } } // warp within context __syncthreads(); @@ -630,7 +1387,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( logits[h] = from_floatx4(dout[h]); } - __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; if (warp_start_token_idx >= context_len) { // warp out of context #pragma unroll @@ -641,6 +1397,139 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } else { // warp in context +#if 0 //fetch v cache + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + //const _B8x8 Vlocalb8 = v_ptrh8be[d]; + //Vlocal[h][b * BLOCK_SIZE / 8 + d] = + // scaled_convert_b8x8(Vlocalb8, v_scale); + } + } + } + } +#endif +#if 0 //cvt vf8 ->f16/bf16 up front + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + for (int vh = 0; vh < VHELOOP; vh++) { + for (int b=0; b < VTLOOP; b++) { + //Vlocal[vh][b] = scaled_convert_b8x8(Vlocalb8[vh][b], v_scale); + Vlocal[vh][b] = convert_b8x8_custom(Vlocalb8[vh][b]); + } + } + } +#endif + + /*Vlocal[vh][x] = scaled_convert_b8x8(Vlocalb8[vh][x], v_scale);\*/ + /*Vlocal[vh][x] = scaled_convert_b8x8_custom(Vlocalb8[vh][x], v_scale);\*/ + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {\ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]);\ + }\ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[0], \ + acc[qh]); \ + acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[1], \ + acc[qh]); \ + } +#if 0 + floatx4 acc[QHLOOP][VHELOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { + acc[qh][vh] = {0}; + } + } +#endif + //#pragma unroll + // for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + //#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } + // iterate over tokens + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + SV_mfma(6); + SV_mfma(7); +#if 0 + SV_mfma(8); + SV_mfma(9); + SV_mfma(10); + SV_mfma(11); + SV_mfma(12); + SV_mfma(13); + SV_mfma(14); + SV_mfma(15); +#endif + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + acc[qh] *= v_scale; + } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); + } + } + //} + +#if 0 + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh][vh]); + } + } +#endif + +#undef SV_mfma +#if 0 // iterate across heads #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { @@ -684,6 +1573,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); } } +#endif } // warp in context __syncthreads(); @@ -787,12 +1677,13 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); +#if 0 //disable this as mfma16 kernel does not support this optimization yet if (num_partitions == 1) { // if num_partitions==1, main kernel will write to out directly, no work in // reduction kernel return; } - +#endif constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -973,6 +1864,60 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} + template \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale, fp8_out_scale_ptr); + #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ paged_attention_ll4mi_QKV_kernel \ @@ -1036,7 +1991,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -1079,65 +2034,82 @@ void paged_attention_custom_launcher( OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - constexpr int NTHR = PARTITION_SIZE; + constexpr int NTHR = 256; //PARTITION_SIZE; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); + //LAUNCH_CUSTOM_ATTENTION(1); + LAUNCH_CUSTOM_ATTENTION_MFMA16(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); + //LAUNCH_CUSTOM_ATTENTION(2); + LAUNCH_CUSTOM_ATTENTION_MFMA16(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); + //LAUNCH_CUSTOM_ATTENTION(3); + LAUNCH_CUSTOM_ATTENTION_MFMA16(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); + //LAUNCH_CUSTOM_ATTENTION(4); + LAUNCH_CUSTOM_ATTENTION_MFMA16(4); break; case 5: - LAUNCH_CUSTOM_ATTENTION(5); + //LAUNCH_CUSTOM_ATTENTION(5); + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - LAUNCH_CUSTOM_ATTENTION(6); + //LAUNCH_CUSTOM_ATTENTION(6); + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - LAUNCH_CUSTOM_ATTENTION(7); + //LAUNCH_CUSTOM_ATTENTION(7); + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - LAUNCH_CUSTOM_ATTENTION(8); + //LAUNCH_CUSTOM_ATTENTION(8); + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - LAUNCH_CUSTOM_ATTENTION(9); + //LAUNCH_CUSTOM_ATTENTION(9); + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - LAUNCH_CUSTOM_ATTENTION(10); + //LAUNCH_CUSTOM_ATTENTION(10); + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - LAUNCH_CUSTOM_ATTENTION(11); + //LAUNCH_CUSTOM_ATTENTION(11); + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - LAUNCH_CUSTOM_ATTENTION(12); + //LAUNCH_CUSTOM_ATTENTION(12); + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - LAUNCH_CUSTOM_ATTENTION(13); + //LAUNCH_CUSTOM_ATTENTION(13); + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - LAUNCH_CUSTOM_ATTENTION(14); + //LAUNCH_CUSTOM_ATTENTION(14); + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - LAUNCH_CUSTOM_ATTENTION(15); + //LAUNCH_CUSTOM_ATTENTION(15); + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - LAUNCH_CUSTOM_ATTENTION(16); + //LAUNCH_CUSTOM_ATTENTION(16); + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); @@ -1149,11 +2121,14 @@ void paged_attention_custom_launcher( // note there are cases with graphing where max_context_len is the max // supported by graphing, not the actual max among all the sequences: in that // case reduction kernel will still run but return immediately - if (max_context_len > PARTITION_SIZE) { + + //below optimization is not yet implemented in mfma16 kernel + //if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); // support upto 8*64*256=128K context length +#if 1 switch (npar_loops) { case 1: LAUNCH_CUSTOM_REDUCTION(1); @@ -1183,7 +2158,8 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break; } - } +#endif + //} //if max_context_len > partition_size } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ @@ -1200,14 +2176,12 @@ void paged_attention_custom_launcher( case 256: \ CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ break; \ - case 512: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); \ - break; \ default: \ TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ break; \ } - +/* +*/ #if defined(__HIPCC__) && defined(__gfx90a__) #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ if (fp8_out_scale) { \ @@ -1229,19 +2203,17 @@ void paged_attention_custom_launcher( case 16: \ CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ - break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } - +/* + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ +*/ #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ - case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ - break; \ case 128: \ CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ @@ -1249,7 +2221,11 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } - +/* + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ + break; \ +*/ void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 13c92dadcd4..745bc9b0807 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -19,30 +19,38 @@ # This will change depending on the compute capability. # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 + # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. +NUM_BLOCKS = 128*1024+4321 # Arbitrary values for testing NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} DTYPES = [ torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] -NUM_GEN_SEQS = [7] # Arbitrary values for testing +] if not current_platform.is_rocm() else [torch.half,torch.bfloat16] +NUM_GEN_SEQS = [17] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +# NUM_HEADS = [(64, 8), (26,2), (16,1), (32,32)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 120, 256] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] +HEAD_SIZES = [128] -BLOCK_SIZES = [16, 32] -USE_ALIBI = [False, True] -KV_CACHE_DTYPE = ["auto", "fp8"] +BLOCK_SIZES = [16] +USE_ALIBI = [False] +# USE_ALIBI = [True] +KV_CACHE_DTYPE = ["auto","fp8"] SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +REF_TENSOR = None +CMP_TENSOR = None def ref_masked_attention( query: torch.Tensor, @@ -117,7 +125,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) + ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -181,6 +189,9 @@ def test_paged_attention( device) key_cache, value_cache = key_caches[0], value_caches[0] + #value_cache = torch.ones_like(value_cache) + #key_cache = torch.ones_like(key_cache) + # Using default kv_scale k_scale = v_scale = 1.0 @@ -213,7 +224,7 @@ def test_paged_attention( elif version in ("v2", "rocm"): if current_platform.is_rocm(): - PARTITION_SIZE = 1024 if version == "v2" else 512 + PARTITION_SIZE = 256 if version == "v2" else PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -248,13 +259,13 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, + '''opcheck(torch.ops._C.paged_attention_v2, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + and block_size == BLOCK_SIZES[0]))''' else: ops.paged_attention_rocm( @@ -275,15 +286,17 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE, ) - opcheck(torch.ops._rocm_C.paged_attention, + '''opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), + kv_cache_dtype, k_scale, v_scale, None, PARTITION_SIZE), cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + and block_size == BLOCK_SIZES[0]))''' else: raise AssertionError(f"Unknown version: {version}") @@ -298,14 +311,14 @@ def test_paged_attention( dtype=dtype, device=device) ops.convert_fp8(dequantized_key_cache, key_cache) - key_cache = dequantized_key_cache + key_cache = k_scale * dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) ops.convert_fp8(dequantized_value_cache, value_cache) - value_cache = dequantized_value_cache + value_cache = v_scale * dequantized_value_cache ref_output = torch.empty_like(query) ref_single_query_cached_kv_attention( @@ -328,9 +341,13 @@ def test_paged_attention( # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. - atol, rtol = 1e-3, 1e-5 + atol, rtol = 1e-4, 1e-5 if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 + atol, rtol = 5e-4, 1e-5 + #bf16 rounding is handled via truncation in new kernel, this increses error + if dtype == torch.bfloat16: + atol = 1e-3 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) @@ -433,4 +450,4 @@ def test_multi_query_kv_attention( ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 66aa2600226..8cdf63b31dd 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,123 +1,123 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 2, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "256": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "512": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1024": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, @@ -126,18 +126,18 @@ "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 58f9e38f522..aa4d65451e0 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,100 +1,100 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 2, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 2, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -108,7 +108,7 @@ "waves_per_eu": 0 }, "512": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, @@ -117,27 +117,27 @@ "waves_per_eu": 0 }, "1024": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 90d0e6f6ba3..ceef67e786a 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,53 +1,53 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -56,53 +56,53 @@ "32": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "256": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 @@ -110,34 +110,34 @@ "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1024": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2048": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -154,11 +154,57 @@ }, "4096": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 9d1b36acd64..e2b865cc03a 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,19 +1,19 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -28,116 +28,116 @@ }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "256": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "512": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1024": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2048": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 9e28dade2ce..0b49304cd83 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,143 +1,143 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "256": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1024": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2048": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index a971953062c..d28b4f17560 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,100 +1,100 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -110,23 +110,23 @@ "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1024": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, @@ -135,7 +135,7 @@ "waves_per_eu": 0 }, "2048": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 1bad9550f06..2ab10c87d21 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,108 +1,108 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "256": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 @@ -110,7 +110,7 @@ "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -126,18 +126,18 @@ "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json index 91260051a53..2d4a79bc633 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -1,10 +1,10 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -18,65 +18,65 @@ "waves_per_eu": 0 }, "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -84,17 +84,17 @@ "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -117,16 +117,16 @@ "waves_per_eu": 0 }, "1024": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "1536": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, @@ -135,9 +135,9 @@ "waves_per_eu": 0 }, "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, @@ -145,20 +145,66 @@ }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 0 + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 } } diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index adca2864707..6caf96c812a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -388,18 +388,28 @@ def get_default_config( is_marlin: bool, ) -> Dict[str, int]: config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 128 if dtype is None else 128, + 'BLOCK_SIZE_K': 128 if dtype is None else 256, + 'GROUP_SIZE_M': 1, + "num_warps": 8 if dtype is None else 8, + "num_stages": 2, + "waves_per_eu": 0, + "kpack": 2, # garbage w/o this + "matrix_instr_nonkdim": 16 # force 16 mfma even for block_m=32 } # A heuristic: fused marlin works faster with this config for small M if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 + 'BLOCK_SIZE_N': 128 if dtype is None else 128, + 'BLOCK_SIZE_K': 128 if dtype is None else 256, + 'GROUP_SIZE_M': 1, + "num_warps": 8 if dtype is None else 8, + "num_stages": 2, + "waves_per_eu": 0, + "kpack": 2, # garbage w/o this + "matrix_instr_nonkdim": 16 # force 16 mfma even for block_m=32 } return config diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9b0c3b06649..8a6d4ee0832 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -379,7 +379,7 @@ def apply(self, def permute_weight_fp8(x: torch.Tensor) -> torch.Tensor: ## Hardcode BLOCK_K and BLOCK_N BK = 256 - BN = 64 #256 #128 + BN = 128 #256 #64 x_ = x x_ = x_.view(x.shape[0], x.shape[1] // BN, BN // 16, 16, x.shape[2] // BK,