From dd40b1e7424058c95bdd6ca5cdf142443d1798f3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 03:21:18 +0000 Subject: [PATCH 01/77] fp8 support Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 9 +++- .../layers/fused_moe/__init__.py | 1 + vllm/model_executor/layers/fused_moe/layer.py | 43 +++++++++++++++++++ .../layers/fused_moe/pplx_prepare_finalize.py | 6 +++ .../model_executor/layers/quantization/fp8.py | 3 ++ 5 files changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 779fa1df086d..c81b290845ad 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -70,6 +70,11 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): + if config.dtype == torch.torch.float8_e4m3fn: + config_dtype = torch.bfloat16 + else: + config_dtype = config.dtype + A = torch.randn( (config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", @@ -97,7 +102,7 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) + [torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("block_shape", [None]) @pytest.mark.parametrize("per_act_token_quant", [False]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, @@ -151,6 +156,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") compute_tl_dtype = { + torch.torch.float8_e4m3fn: tl.bfloat16, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32 @@ -196,6 +202,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, block_shape) rtol, atol = { + torch.torch.float8_e4m3fn: (6e-2, 6e-2), torch.float16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2), torch.float32: (1e-2, 1e-2), diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3d40879b4ccb..897e6700e7c4 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -38,6 +38,7 @@ def get_config() -> Optional[dict[str, Any]]: "FusedMoEPrepareAndFinalize", "override_config", "get_config", + "MOE_DP_CHUNK_SIZE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6f9770262856..77e421554823 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -378,6 +378,47 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) + def set_prepare_finalize( + self, + dp_size: int, + world_size: int, + prepare_finalize: FusedMoEPrepareAndFinalize, + ) -> bool: + assert self.fused_experts == fused_experts + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + return True + def forward_cuda( self, layer: torch.nn.Module, @@ -1299,6 +1340,8 @@ def select_experts( topk_ids = topk_ids.to(dtype=indices_type) + assert topk_ids.dtype == indices_type + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 45e813287d3f..d67b3b9f0e58 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -127,6 +127,10 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) + if a1q_scale is not None and a1q_scale.dim() == 1: + assert a1q_scale.numel() == 1 + a1q_scale = a1q_scale.view(1, 1) + # rem_experts need to be 0 for pplx to work properly. rem_experts = num_experts % self.world_size assert rem_experts == 0 @@ -162,6 +166,8 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None + #print(f"SCALE= {a1q_scale.shape}") + self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0295f5e2a1c8..926a5f4a7b01 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -481,6 +481,9 @@ def __init__(self, quant_config: Fp8Config): block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) + self.use_pplx_kernels = False + self.rocm_aiter_moe_enabled = False + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): From a0511081afedb331a4ce7bb8cf8352df70114fdb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 21:17:52 +0000 Subject: [PATCH 02/77] wip Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 1 + .../layers/fused_moe/fused_batched_moe.py | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index c81b290845ad..bc8c82974c48 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -7,6 +7,7 @@ import pytest import torch import triton.language as tl +from typing import Optional from tests.kernels.moe.utils import (batched_moe, make_quantized_test_activations, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 37a109857ac3..d1dc4e872c7c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -539,6 +539,8 @@ def __init__( self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size + self.use_fp8_w8a8 = use_fp8_w8a8 + self.block_shape = block_shape @property def activation_formats( @@ -571,6 +573,66 @@ def workspace_shapes( workspace2 = (self.max_num_tokens * num_dp, N) return (workspace13, workspace2, workspace13, a.dtype) + def native_w8a8_block_matmul(A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. + """ + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert self.block_shape is not None and len(self.block_shape) == 2 + block_n, block_k = self.block_shape[0], self.block_shape[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + def apply( self, output: torch.Tensor, From 43f9cfe1c7d7a3689f897f98e967ff3f4bfb4bc7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 23:13:43 +0000 Subject: [PATCH 03/77] test Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index bc8c82974c48..ad42ccd07e0e 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -71,10 +71,10 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): - if config.dtype == torch.torch.float8_e4m3fn: - config_dtype = torch.bfloat16 + if config.in_dtype == torch.torch.float8_e4m3fn: + config_in_dtype = torch.bfloat16 else: - config_dtype = config.dtype + config_in_dtype = config.in_dtype A = torch.randn( (config.num_experts, config.max_tokens_per_expert, config.K), @@ -157,7 +157,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") compute_tl_dtype = { - torch.torch.float8_e4m3fn: tl.bfloat16, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32 @@ -202,8 +201,15 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, A_scale, B_scale, block_shape) + ref_output2 = ref_impl(tensors.A, + tensors.B, + ref_output2, + tensors.num_expert_tokens, + A_scale, + B_scale, + block_shape[-2:]) + rtol, atol = { - torch.torch.float8_e4m3fn: (6e-2, 6e-2), torch.float16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2), torch.float32: (1e-2, 1e-2), From 39a2ab3e748a79ace0db9b24af54f254cad9693b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 23:20:54 +0000 Subject: [PATCH 04/77] basic working test Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index d1dc4e872c7c..041b915a432a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -17,38 +17,38 @@ @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr): - + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr +): offs_k = tl.arange(0, BLOCK_K) if use_w8a16: @@ -312,22 +312,22 @@ def batched_triton_kernel( def invoke_moe_batched_triton_kernel( - A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - block_shape: Optional[list[int]] = None): - + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None +): assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) From b5996eca9cb7f5e63419f5791ada29984024b398 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 22 May 2025 23:29:08 +0000 Subject: [PATCH 05/77] tests + fix Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 122 ++++++++---------- 1 file changed, 52 insertions(+), 70 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 041b915a432a..47e03ba30db2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -11,8 +11,11 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, + _resize_cache, + cdiv) @triton.jit @@ -400,6 +403,8 @@ def __init__( self.dp_size = dp_size self.rank = rank self.max_num_tokens = max_num_tokens + self.use_fp8_w8a8 = use_fp8_w8a8 + self.block_shape = block_shape @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -435,6 +440,8 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) + _, block_k = self.block_shape + num_tokens, hidden_dim = a1.size() topk = topk_ids.size(1) @@ -456,9 +463,14 @@ def prepare( dtype=b_type, device=a1.device) - b_a1_scale = None - - assert quant_config.quant_dtype is None, "quantization NYI" + if self.use_fp8_w8a8: + k_tiles = (hidden_dim + block_k - 1) // block_k + b_a1_scale = torch.zeros( + (num_local_experts, self.max_num_tokens, k_tiles), + dtype=torch.float32, + device=a1.device) + else: + b_a1_scale = None first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts @@ -469,10 +481,14 @@ def prepare( if rows == 0: continue idx = expert_id - first_expert - b_a1[idx, :rows, :] = a1[:topks.numel()][topks] - tokens_per_expert[idx] = rows + if self.use_fp8_w8a8: + # TODO: use _fp8_quantize + b_a1[idx, :rows, :], tmp_scale = per_token_group_quant_fp8(rhs, block_k) + b_a1_scale[idx, :rows] = tmp_scale # inline? + else: + b_a1[idx, :rows, :] = rhs - assert b_a1_scale is None or b_a1_scale.ndim == 3 + tokens_per_expert[idx] = rows return b_a1, b_a1_scale, tokens_per_expert, None, None @@ -573,66 +589,6 @@ def workspace_shapes( workspace2 = (self.max_num_tokens * num_dp, N) return (workspace13, workspace2, workspace13, a.dtype) - def native_w8a8_block_matmul(A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor): - """This function performs matrix multiplication with block-wise - quantization using native torch. - It is agnostic to the input data type and can be used for both int8 and - fp8 data types. - - It takes two input tensors `A` and `B` (int8) with scales `As` and - `Bs` (float32). - The output is returned in the specified `output_dtype`. - """ - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert self.block_shape is not None and len(self.block_shape) == 2 - block_n, block_k = self.block_shape[0], self.block_shape[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - def apply( self, output: torch.Tensor, @@ -833,6 +789,7 @@ def apply( if self.use_fp8_w8a8: intermediate_cache1.fill_(0) + assert not self.use_fp8_w8a8 or a1q_scale is not None # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, @@ -866,8 +823,33 @@ def apply( per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) - qintermediate_cache2 = qintermediate_cache2.view( - (E, -1, ic2_hidden_size)) + # TODO (varun) : support w8a8 + #assert not self.use_fp8_w8a8 + if self.use_fp8_w8a8: + per_act_token = False + qintermediate_cache2 = torch.zeros_like(intermediate_cache2, + dtype=torch.float8_e4m3fn) + block_n = self.block_shape[0] + n_tiles = ((N // 2) + block_n - 1) // block_n + scale_shape = (E, num_tokens, n_tiles) + a2q_scale = torch.zeros(scale_shape, + dtype=torch.float32, + device=hidden_states.device) + for e in range(E): + num_tokens = expert_num_tokens[e] + if num_tokens > 0: + #qintermediate_cache2[e], tmp_scale = _fp8_quantize( + # intermediate_cache2[e], + # a2_scale[e] if a2_scale is not None else None, + # per_act_token, self.block_shape) + qintermediate_cache2[e, :num_tokens, :], tmp_scale = per_token_group_quant_fp8( + intermediate_cache2[e, :num_tokens], block_n) + #print(a2q_scale[e, :tmp_scale.shape[0]].shape) + #print(tmp_scale.shape) + a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale + else: + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, From 356d4d7a024062c78ba1480b70ab607674d055e9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 27 May 2025 17:45:04 +0000 Subject: [PATCH 06/77] stuff Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 9 +- .../layers/fused_moe/fused_batched_moe.py | 116 +++++++++--------- 2 files changed, 59 insertions(+), 66 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index ad42ccd07e0e..8abbbb273623 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -7,7 +7,6 @@ import pytest import torch import triton.language as tl -from typing import Optional from tests.kernels.moe.utils import (batched_moe, make_quantized_test_activations, @@ -201,12 +200,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, A_scale, B_scale, block_shape) - ref_output2 = ref_impl(tensors.A, - tensors.B, - ref_output2, - tensors.num_expert_tokens, - A_scale, - B_scale, + ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2, + tensors.num_expert_tokens, A_scale, B_scale, block_shape[-2:]) rtol, atol = { diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 47e03ba30db2..735f9884cadf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -11,47 +11,44 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache, - cdiv) @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr -): + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): offs_k = tl.arange(0, BLOCK_K) if use_w8a16: @@ -315,22 +312,21 @@ def batched_triton_kernel( def invoke_moe_batched_triton_kernel( - A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - block_shape: Optional[list[int]] = None -): + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) @@ -480,11 +476,12 @@ def prepare( rows = torch.count_nonzero(topks.flatten()) if rows == 0: continue + rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert if self.use_fp8_w8a8: # TODO: use _fp8_quantize - b_a1[idx, :rows, :], tmp_scale = per_token_group_quant_fp8(rhs, block_k) - b_a1_scale[idx, :rows] = tmp_scale # inline? + b_a1[idx, :rows, :], b_a1_scale[ + idx, :rows] = per_token_group_quant_fp8(rhs, block_k) else: b_a1[idx, :rows, :] = rhs @@ -827,12 +824,13 @@ def apply( #assert not self.use_fp8_w8a8 if self.use_fp8_w8a8: per_act_token = False - qintermediate_cache2 = torch.zeros_like(intermediate_cache2, + # TODO: reuse? + qintermediate_cache2 = torch.empty_like(intermediate_cache2, dtype=torch.float8_e4m3fn) block_n = self.block_shape[0] n_tiles = ((N // 2) + block_n - 1) // block_n scale_shape = (E, num_tokens, n_tiles) - a2q_scale = torch.zeros(scale_shape, + a2q_scale = torch.empty(scale_shape, dtype=torch.float32, device=hidden_states.device) for e in range(E): @@ -842,10 +840,10 @@ def apply( # intermediate_cache2[e], # a2_scale[e] if a2_scale is not None else None, # per_act_token, self.block_shape) - qintermediate_cache2[e, :num_tokens, :], tmp_scale = per_token_group_quant_fp8( - intermediate_cache2[e, :num_tokens], block_n) - #print(a2q_scale[e, :tmp_scale.shape[0]].shape) - #print(tmp_scale.shape) + qintermediate_cache2[ + e, : + num_tokens, :], tmp_scale = per_token_group_quant_fp8( + intermediate_cache2[e, :num_tokens], block_n) a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale else: qintermediate_cache2 = intermediate_cache2 From ad55ba15cf2da3d8e4ac5f4b90fd95ef9d08b33e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 28 May 2025 20:19:41 +0000 Subject: [PATCH 07/77] cleanup quantization Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 116 +++++++++++------- vllm/model_executor/layers/fused_moe/layer.py | 15 +-- 3 files changed, 71 insertions(+), 62 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 8abbbb273623..87af80862094 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -217,7 +217,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("per_act_token_quant", [False]) @pytest.mark.parametrize("block_shape", [None]) def test_fused_moe_batched_experts( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 735f9884cadf..08fc592d8db6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -11,9 +11,9 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, + moe_kernel_quantize_input) @triton.jit @@ -49,6 +49,7 @@ def moe_mmk( compute_type: tl.constexpr, use_w8a8: tl.constexpr, use_w8a16: tl.constexpr): + offs_k = tl.arange(0, BLOCK_K) if use_w8a16: @@ -327,6 +328,7 @@ def invoke_moe_batched_triton_kernel( use_int4_w4a16: bool, config: dict[str, int], block_shape: Optional[list[int]] = None): + assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) @@ -399,8 +401,9 @@ def __init__( self.dp_size = dp_size self.rank = rank self.max_num_tokens = max_num_tokens - self.use_fp8_w8a8 = use_fp8_w8a8 + self.per_act_token = per_act_token self.block_shape = block_shape + self.qtype = qtype @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -459,7 +462,7 @@ def prepare( dtype=b_type, device=a1.device) - if self.use_fp8_w8a8: + if self.qtype is not None: k_tiles = (hidden_dim + block_k - 1) // block_k b_a1_scale = torch.zeros( (num_local_experts, self.max_num_tokens, k_tiles), @@ -478,10 +481,20 @@ def prepare( continue rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert - if self.use_fp8_w8a8: - # TODO: use _fp8_quantize - b_a1[idx, :rows, :], b_a1_scale[ - idx, :rows] = per_token_group_quant_fp8(rhs, block_k) + if self.qtype is not None: + if a1_scale is not None: + rhs_a1_scale = a1_scale[:topks.numel()][topks] + else: + rhs_a1_scale = None + b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = ( + moe_kernel_quantize_input( + rhs, + rhs_a1_scale, + self.qtype, + self.per_act_token, + self.block_shape, + ) + ) else: b_a1[idx, :rows, :] = rhs @@ -632,6 +645,42 @@ def apply( output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) +def batched_moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + num_tokens: int, + E: int, + N: int, + expert_num_tokens: torch.Tensor, + qtype: Optional[torch.dtype], + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if qtype is not None: + assert block_shape is not None + A_q = torch.empty_like(A, dtype=qtype) + block_n, block_k = block_shape + n_tiles = ((N // 2) + block_n - 1) // block_n + scale_shape = (E, num_tokens, n_tiles) + A_q_scale = torch.empty(scale_shape, + dtype=torch.float32, + device=A.device) + for e in range(E): + num_tokens = expert_num_tokens[e] + if num_tokens > 0: + A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( + A[e, :num_tokens], + A_scale[e, :num_tokens] if A_scale else None, + qtype, + per_channel_quant, + [block_k, block_n]) + A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale + + return A_q, A_q_scale + else: + return A, A_scale + + class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): """ A Triton based MoE expert class that operates on expert batched format, @@ -810,44 +859,17 @@ def apply( self.activation(activation, intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N)) - ic2_hidden_size = intermediate_cache2.size(-1) - intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size) - - qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - A=intermediate_cache2, - A_scale=a2_scale, - quant_dtype=self.quant_dtype, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) - - # TODO (varun) : support w8a8 - #assert not self.use_fp8_w8a8 - if self.use_fp8_w8a8: - per_act_token = False - # TODO: reuse? - qintermediate_cache2 = torch.empty_like(intermediate_cache2, - dtype=torch.float8_e4m3fn) - block_n = self.block_shape[0] - n_tiles = ((N // 2) + block_n - 1) // block_n - scale_shape = (E, num_tokens, n_tiles) - a2q_scale = torch.empty(scale_shape, - dtype=torch.float32, - device=hidden_states.device) - for e in range(E): - num_tokens = expert_num_tokens[e] - if num_tokens > 0: - #qintermediate_cache2[e], tmp_scale = _fp8_quantize( - # intermediate_cache2[e], - # a2_scale[e] if a2_scale is not None else None, - # per_act_token, self.block_shape) - qintermediate_cache2[ - e, : - num_tokens, :], tmp_scale = per_token_group_quant_fp8( - intermediate_cache2[e, :num_tokens], block_n) - a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale - else: - qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale + qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( + intermediate_cache2, + a2_scale, + num_tokens, + E, + N, + expert_num_tokens, + self.qtype, + self.per_act_token, + self.block_shape + ) invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 77e421554823..245c8b82462f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -395,22 +395,10 @@ def set_prepare_finalize( max_num_tokens=MOE_DP_CHUNK_SIZE, world_size=world_size, dp_size=dp_size, - use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, ) else: logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) + experts = TritonExperts() self.fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -773,7 +761,6 @@ def __init__( num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, quant_config=quant_config, ) self.moe_config = moe From 347f58e11aa73a6a97ef122b141b51801f95a5e3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 28 May 2025 23:09:32 +0000 Subject: [PATCH 08/77] merge Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 29 ------------------- .../layers/fused_moe/pplx_prepare_finalize.py | 2 -- .../model_executor/layers/quantization/fp8.py | 2 +- 3 files changed, 1 insertion(+), 32 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 245c8b82462f..37948b83741f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -378,35 +378,6 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: FusedMoEPrepareAndFinalize, - ) -> bool: - assert self.fused_experts == fused_experts - - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - if isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): - logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=world_size, - dp_size=dp_size, - ) - else: - logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts() - - self.fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - return True - def forward_cuda( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index d67b3b9f0e58..9b5b56cf4a48 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -166,8 +166,6 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - #print(f"SCALE= {a1q_scale.shape}") - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 926a5f4a7b01..c06b0fe9f36f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,7 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat, From 035d3249b5f8e270d2cb56d5286eb613a717e172 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 28 May 2025 23:29:30 +0000 Subject: [PATCH 09/77] fix merge Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 08fc592d8db6..2500c76de90d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -439,8 +439,6 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - _, block_k = self.block_shape - num_tokens, hidden_dim = a1.size() topk = topk_ids.size(1) @@ -463,6 +461,7 @@ def prepare( device=a1.device) if self.qtype is not None: + _, block_k = self.block_shape k_tiles = (hidden_dim + block_k - 1) // block_k b_a1_scale = torch.zeros( (num_local_experts, self.max_num_tokens, k_tiles), From ac46906709c3f0a04003534a389ca09738cc1f89 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 28 May 2025 23:37:04 +0000 Subject: [PATCH 10/77] lint Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 2500c76de90d..aee7b40c7d9f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, - moe_kernel_quantize_input) + _resize_cache, moe_kernel_quantize_input) @triton.jit @@ -492,8 +491,7 @@ def prepare( self.qtype, self.per_act_token, self.block_shape, - ) - ) + )) else: b_a1[idx, :rows, :] = rhs @@ -669,10 +667,8 @@ def batched_moe_kernel_quantize_input( if num_tokens > 0: A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( A[e, :num_tokens], - A_scale[e, :num_tokens] if A_scale else None, - qtype, - per_channel_quant, - [block_k, block_n]) + A_scale[e, :num_tokens] if A_scale else None, qtype, + per_channel_quant, [block_k, block_n]) A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale @@ -859,16 +855,8 @@ def apply( intermediate_cache1.view(-1, N)) qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, - a2_scale, - num_tokens, - E, - N, - expert_num_tokens, - self.qtype, - self.per_act_token, - self.block_shape - ) + intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens, + self.qtype, self.per_act_token, self.block_shape) invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, From 1a5d6b323c348334482ab4229d57341f6601ef82 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 29 May 2025 02:08:22 +0000 Subject: [PATCH 11/77] fixes Signed-off-by: Bill Nell --- vllm/distributed/device_communicators/all2all.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 85f87cb21edc..9e698819cbb6 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -84,6 +84,9 @@ def __init__(self, cpu_group): ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa super().__init__(cpu_group) + # Intranode doesn't work yet. + self.internode = True + if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly From 29e314c0cc80af00167e349131806ed1387acec6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 29 May 2025 18:50:37 +0000 Subject: [PATCH 12/77] pplx + fp8 test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 186e00800a17..4194c6ad250a 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,6 +28,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.platforms import current_platform from vllm.utils import round_up @@ -419,6 +421,8 @@ def pplx_moe( world_size, rank, dp_size, + quant_dtype=qtype, + block_shape=block_shape, ) experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, @@ -470,7 +474,7 @@ def pplx_moe( w2_scale=w2_scale_chunk, global_num_experts=num_experts) - if use_cudagraphs: + if False and use_cudagraphs: #XXXXXXXXXXXX out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() @@ -606,7 +610,7 @@ def _pplx_moe( @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index aee7b40c7d9f..0e9add7f0c77 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -467,6 +467,7 @@ def prepare( dtype=torch.float32, device=a1.device) else: + assert a1_scale is None b_a1_scale = None first_expert = num_local_experts * self.rank @@ -830,7 +831,6 @@ def apply( if self.use_fp8_w8a8: intermediate_cache1.fill_(0) - assert not self.use_fp8_w8a8 or a1q_scale is not None # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, From cd5bc8f4821539a862d46d59d0c5f3a5ca7fbbc6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 29 May 2025 21:25:33 +0000 Subject: [PATCH 13/77] fp8 + pplx tests + fixes Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 1 + .../layers/fused_moe/pplx_prepare_finalize.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 0e9add7f0c77..652f4794c893 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -13,6 +13,7 @@ get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) +from vllm.utils import round_up @triton.jit diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 9b5b56cf4a48..25f4f7790c6c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -152,8 +152,10 @@ def prepare( expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: - block_size = (quant_config.block_shape[1] - if quant_config.block_shape is not None else 1) + float32_size = torch.float32.itemsize + block_size = (quant_config.block_shape[1] if quant_config. + block_shape is not None else 1) * float32_size + # zeros? expert_x_scale = torch.empty( (num_local_experts, expert_x.size(1), round_up( From 037eb4a54c67e62375a6016dc8f4d2c113fbec11 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 30 May 2025 00:12:54 +0000 Subject: [PATCH 14/77] re-enable cudagraph+torch.compile Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 4194c6ad250a..9ae02b77f873 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -474,7 +474,7 @@ def pplx_moe( w2_scale=w2_scale_chunk, global_num_experts=num_experts) - if False and use_cudagraphs: #XXXXXXXXXXXX + if use_cudagraphs: out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 652f4794c893..144a7ee2bcf8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -655,6 +655,26 @@ def batched_moe_kernel_quantize_input( per_channel_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if (torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): + # Note: this does a bunch of extra work because expert_num_tokens is ignored + # but it does support torch.compile + cudagraphs. + hidden_dim = A.size(-1) + if block_shape is not None: + block_shape = [block_shape[1], block_shape[0]] + assert A_scale is None or A_scale.dim() == 2 + A_q, A_q_scale = moe_kernel_quantize_input( + A.view(-1, hidden_dim), + A_scale, + qtype, + per_channel_quant, + block_shape) + A_q = A_q.view(E, -1, hidden_dim) + if A_q_scale is not None: + A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) + return A_q, A_q_scale + + if qtype is not None: assert block_shape is not None A_q = torch.empty_like(A, dtype=qtype) From 43dd36b9133227b4071ab7b175fc4b9dbe994b63 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 30 May 2025 02:33:58 +0000 Subject: [PATCH 15/77] hacks Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 144a7ee2bcf8..539bb5cc820f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -655,7 +655,8 @@ def batched_moe_kernel_quantize_input( per_channel_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (torch.compiler.is_compiling() + if (True or + torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is ignored # but it does support torch.compile + cudagraphs. From a778f5aac14d4c4a2cafd4c66c1da31ddee7cddd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 31 May 2025 01:32:51 +0000 Subject: [PATCH 16/77] clean up quantization parameters Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 3 +- .../layers/fused_moe/fused_batched_moe.py | 30 ++++++++++++------- .../layers/fused_moe/modular_kernel.py | 18 +++++++++++ .../layers/fused_moe/triton_deep_gemm_moe.py | 4 ++- 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 9ae02b77f873..30ab7b8b1275 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -372,7 +372,7 @@ def pplx_moe( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, qtype: Optional[torch.dtype] = None, - per_act_token_quant=False, + per_act_token_quant = False, block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, @@ -422,6 +422,7 @@ def pplx_moe( rank, dp_size, quant_dtype=qtype, + per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 539bb5cc820f..7c273b600502 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, try_get_optimal_moe_config) + get_config_dtype_str, try_get_optimal_moe_config, + get_config_quant_dtype) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) from vllm.utils import round_up @@ -401,9 +402,6 @@ def __init__( self.dp_size = dp_size self.rank = rank self.max_num_tokens = max_num_tokens - self.per_act_token = per_act_token - self.block_shape = block_shape - self.qtype = qtype @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -460,7 +458,7 @@ def prepare( dtype=b_type, device=a1.device) - if self.qtype is not None: + if self.quant_dtype is not None: _, block_k = self.block_shape k_tiles = (hidden_dim + block_k - 1) // block_k b_a1_scale = torch.zeros( @@ -481,7 +479,7 @@ def prepare( continue rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert - if self.qtype is not None: + if self.quant_dtype is not None: if a1_scale is not None: rhs_a1_scale = a1_scale[:topks.numel()][topks] else: @@ -490,8 +488,8 @@ def prepare( moe_kernel_quantize_input( rhs, rhs_a1_scale, - self.qtype, - self.per_act_token, + self.quant_dtype, + self.per_act_token_quant, self.block_shape, )) else: @@ -561,11 +559,21 @@ def __init__( assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + quant_dtype = get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ) + super().__init__( + quant_dtype=quant_dtype, + per_act_token_quant=False, # TODO (bnell): quantization + block_shape=block_shape, + ) + assert block_m is None self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - self.use_fp8_w8a8 = use_fp8_w8a8 - self.block_shape = block_shape @property def activation_formats( @@ -878,7 +886,7 @@ def apply( qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens, - self.qtype, self.per_act_token, self.block_shape) + self.quant_dtype, self.per_act_token_quant, self.block_shape) invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2ffb4d328eca..66fcea51b41a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -101,6 +101,15 @@ class FusedMoEPrepareAndFinalize(ABC): An abstract base class for the [Quantize-Prepare] and [Finalize] steps described above. """ + def __init__( + self, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], + ): + self.quant_dtype = quant_dtype + self.per_act_token_quant = per_act_token_quant + self.block_shape = block_shape @abstractmethod def prepare( @@ -199,6 +208,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): An abstract base class for the [Permute-Experts-Unpermute] step described above. """ + def __init__( + self, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], + ): + self.quant_dtype = quant_dtype + self.per_act_token_quant = per_act_token_quant + self.block_shape = block_shape def __init__( self, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index e660376ebe6b..f1ac12e9799f 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -8,7 +8,9 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) -from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_config_quant_dtype, + TritonExperts) class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): From 0ec77aa68011bd1230a3e9c24e3367cd593a87b6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 31 May 2025 01:51:11 +0000 Subject: [PATCH 17/77] lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 26 +++++++------------ .../layers/fused_moe/modular_kernel.py | 2 ++ .../layers/fused_moe/triton_deep_gemm_moe.py | 3 +-- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 30ab7b8b1275..4bafcbeffefc 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -372,7 +372,7 @@ def pplx_moe( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, qtype: Optional[torch.dtype] = None, - per_act_token_quant = False, + per_act_token_quant=False, block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7c273b600502..1c4192fdad7d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -10,11 +10,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, try_get_optimal_moe_config, - get_config_quant_dtype) + get_config_dtype_str, get_config_quant_dtype, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) -from vllm.utils import round_up @triton.jit @@ -567,7 +565,7 @@ def __init__( ) super().__init__( quant_dtype=quant_dtype, - per_act_token_quant=False, # TODO (bnell): quantization + per_act_token_quant=False, # TODO (bnell): quantization block_shape=block_shape, ) assert block_m is None @@ -663,27 +661,23 @@ def batched_moe_kernel_quantize_input( per_channel_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (True or - torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): - # Note: this does a bunch of extra work because expert_num_tokens is ignored - # but it does support torch.compile + cudagraphs. + if (True or torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): + # Note: this does a bunch of extra work because expert_num_tokens is + # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) if block_shape is not None: block_shape = [block_shape[1], block_shape[0]] assert A_scale is None or A_scale.dim() == 2 - A_q, A_q_scale = moe_kernel_quantize_input( - A.view(-1, hidden_dim), - A_scale, - qtype, - per_channel_quant, - block_shape) + A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, + hidden_dim), A_scale, + qtype, per_channel_quant, + block_shape) A_q = A_q.view(E, -1, hidden_dim) if A_q_scale is not None: A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) return A_q, A_q_scale - if qtype is not None: assert block_shape is not None A_q = torch.empty_like(A, dtype=qtype) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 66fcea51b41a..c1b8b2317c34 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -101,6 +101,7 @@ class FusedMoEPrepareAndFinalize(ABC): An abstract base class for the [Quantize-Prepare] and [Finalize] steps described above. """ + def __init__( self, quant_dtype: Optional[torch.dtype], @@ -208,6 +209,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): An abstract base class for the [Permute-Experts-Unpermute] step described above. """ + def __init__( self, quant_dtype: Optional[torch.dtype], diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index f1ac12e9799f..512f0048cbfe 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -9,8 +9,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_quant_dtype, - TritonExperts) + TritonExperts, get_config_quant_dtype) class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): From c6a44511f0500a1e2011fbfc915cf4369f03b4f2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 1 Jun 2025 19:43:11 +0000 Subject: [PATCH 18/77] progress on grouped quant for batched experts Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 2 + tests/kernels/moe/test_pplx_moe.py | 10 +++ .../layers/fused_moe/fused_batched_moe.py | 80 ++++++++++++++----- .../layers/fused_moe/pplx_prepare_finalize.py | 36 +++++++-- .../layers/fused_moe/triton_deep_gemm_moe.py | 4 +- vllm/model_executor/models/qwen3_moe.py | 2 + 6 files changed, 106 insertions(+), 28 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 87af80862094..016955809313 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -16,6 +16,8 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( invoke_moe_batched_triton_kernel) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 4bafcbeffefc..f4d43ed73ac9 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,6 +28,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform @@ -35,11 +37,19 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch +from tests.kernels.moe.utils import ( + native_w8a8_block_matmul, + torch_moe2, + naive_batched_moe, +) + + requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", ) + PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), (222, 2048, 1024)] diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 1c4192fdad7d..cbfc126b9a5a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -29,8 +29,9 @@ def moe_mmk( # (A has M rows). stride_ak, stride_bk, - stride_asm, + stride_ase, stride_ask, + stride_asm, stride_bse, stride_bsk, stride_bsn, @@ -47,7 +48,9 @@ def moe_mmk( BLOCK_K: tl.constexpr, compute_type: tl.constexpr, use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr): + use_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, +): offs_k = tl.arange(0, BLOCK_K) @@ -59,10 +62,19 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + a_scale_ptrs = a_scale_ptr + expert_id * stride_ase + offs_m * stride_asm offs_bsn = offs_n // group_n b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + offs_bsn * stride_bsn) + + # channel-wise + elif per_channel_quant: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None] + # tensor-wise else: a_scale = tl.load(a_scale_ptr) @@ -142,8 +154,9 @@ def expert_triton_kernel( stride_bn, stride_cm, stride_cn, - stride_asm, + stride_ase, stride_ask, + stride_asm, stride_bse, stride_bsk, stride_bsn, @@ -153,10 +166,12 @@ def expert_triton_kernel( # Quantization schemes use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, # Kernel config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr): + BLOCK_K: tl.constexpr, +): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) % N @@ -179,8 +194,9 @@ def expert_triton_kernel( # (A has M rows). stride_ak, stride_bk, - stride_asm, + stride_ase, stride_ask, + stride_asm, stride_bse, stride_bsk, stride_bsn, @@ -197,7 +213,8 @@ def expert_triton_kernel( BLOCK_K, compute_type, use_fp8_w8a8, - use_int8_w8a16) + use_int8_w8a16, + per_channel_quant) # store in C offs_cn = tl.arange(0, BLOCK_N) @@ -234,8 +251,9 @@ def batched_triton_kernel( stride_ce, stride_cm, stride_cn, - stride_asm, + stride_ase, stride_ask, + stride_asm, stride_bse, stride_bsk, stride_bsn, @@ -245,6 +263,7 @@ def batched_triton_kernel( # Quantization schemes use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, # Kernel config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -294,8 +313,9 @@ def batched_triton_kernel( stride_bn, stride_cm, stride_cn, - stride_asm, + stride_ase, stride_ask, + stride_asm, stride_bse, stride_bsk, stride_bsn, @@ -305,6 +325,7 @@ def batched_triton_kernel( # Quantization schemes use_fp8_w8a8, use_int8_w8a16, + per_channel_quant, # Kernel config BLOCK_M, BLOCK_N, @@ -326,6 +347,7 @@ def invoke_moe_batched_triton_kernel( use_int8_w8a16: bool, use_int4_w4a16: bool, config: dict[str, int], + per_act_token_quant: bool, block_shape: Optional[list[int]] = None): assert not use_int4_w4a16 @@ -364,8 +386,9 @@ def invoke_moe_batched_triton_kernel( C.stride(0), C.stride(1), C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim >= 2 else 0, + A_scale.stride(2) if A_scale is not None and A_scale.ndim == 3 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim >= 2 else 0, B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, @@ -375,6 +398,7 @@ def invoke_moe_batched_triton_kernel( # Quantization schemes use_fp8_w8a8, use_int8_w8a16, + per_act_token_quant, # Kernel config BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, @@ -457,10 +481,18 @@ def prepare( device=a1.device) if self.quant_dtype is not None: - _, block_k = self.block_shape - k_tiles = (hidden_dim + block_k - 1) // block_k + if self.block_shape is not None: + _, block_k = self.block_shape + k_tiles = (hidden_dim + block_k - 1) // block_k + scale_shape = (num_local_experts, self.max_num_tokens, k_tiles) + else: + num = self.max_num_tokens if self.per_act_token_quant else 1 + scale_shape = (num_local_experts, num, 1) + + print(f"SCALE_SHAPE {b_a1.shape} {scale_shape}") + b_a1_scale = torch.zeros( - (num_local_experts, self.max_num_tokens, k_tiles), + scale_shape, dtype=torch.float32, device=a1.device) else: @@ -479,10 +511,11 @@ def prepare( idx = expert_id - first_expert if self.quant_dtype is not None: if a1_scale is not None: + assert False, "NYI" rhs_a1_scale = a1_scale[:topks.numel()][topks] else: rhs_a1_scale = None - b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = ( + b_a1[idx, :rows, :], b_s = ( moe_kernel_quantize_input( rhs, rhs_a1_scale, @@ -490,6 +523,10 @@ def prepare( self.per_act_token_quant, self.block_shape, )) + if self.block_shape is None and not self.per_act_token_quant: + b_a1_scale[idx] = b_s + else: + b_a1_scale[idx, :rows] = b_s else: b_a1[idx, :rows, :] = rhs @@ -565,7 +602,7 @@ def __init__( ) super().__init__( quant_dtype=quant_dtype, - per_act_token_quant=False, # TODO (bnell): quantization + per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) assert block_m is None @@ -666,8 +703,6 @@ def batched_moe_kernel_quantize_input( # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) - if block_shape is not None: - block_shape = [block_shape[1], block_shape[0]] assert A_scale is None or A_scale.dim() == 2 A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, hidden_dim), A_scale, @@ -675,7 +710,10 @@ def batched_moe_kernel_quantize_input( block_shape) A_q = A_q.view(E, -1, hidden_dim) if A_q_scale is not None: - A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) + if A_q_scale.ndim == 1: + A_q_scale = torch.repeat_interleave(A_q_scale, E, dim=0).view(E, 1, 1) + else: + A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) return A_q, A_q_scale if qtype is not None: @@ -693,7 +731,7 @@ def batched_moe_kernel_quantize_input( A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( A[e, :num_tokens], A_scale[e, :num_tokens] if A_scale else None, qtype, - per_channel_quant, [block_k, block_n]) + per_channel_quant, block_shape) A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale @@ -869,6 +907,7 @@ def apply( use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, config=config, + per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) intermediate_cache2.fill_(0) @@ -894,4 +933,5 @@ def apply( use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, config=config, + per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 25f4f7790c6c..7726f2a92b51 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -112,6 +112,7 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) + repeat_cols = 4 repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) a1q, a1q_scale = moe_kernel_quantize_input( @@ -127,6 +128,14 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) + # per_act_token_quant = a1_scale.numel() != 1 if a1_scale is not None else ( + # a2_scale.numel() != 1 if a2_scale is not None else False) + + # a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + # self.quant_dtype, + # per_act_token, + # self.block_shape) + if a1q_scale is not None and a1q_scale.dim() == 1: assert a1q_scale.numel() == 1 a1q_scale = a1q_scale.view(1, 1) @@ -155,15 +164,24 @@ def prepare( float32_size = torch.float32.itemsize block_size = (quant_config.block_shape[1] if quant_config. block_shape is not None else 1) * float32_size - # zeros? - expert_x_scale = torch.empty( - (num_local_experts, expert_x.size(1), - round_up( - (expert_x.size(2) + block_size - 1) // block_size, 4)), + + + expert_x_scale_shape = ( + num_local_experts, + expert_x.size(1), + (expert_x.size(2) + block_size - 1) // block_size, + ) + + print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}") + + expert_x_scale = torch.zeros( + expert_x_scale_shape, dtype=torch.float32, - device=device, + device=expert_x.device, ) + print(f"YYYYYYYYYYYYYYY {expert_x.shape}") + # This argument is optional, defaults to indices.size(0) # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None @@ -180,6 +198,10 @@ def prepare( if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] + print(f"ZZZZZZZZZZZZZZ") + if expert_x_scale is not None: + expert_x_scale = expert_x_scale[:, :, 0:1] + return expert_x, expert_x_scale, expert_num_tokens, None, None def finalize( @@ -205,6 +227,8 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) + print("CCCCCCCCCCCCCCCCCCCC") + self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 512f0048cbfe..4a1c7d4be1ba 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -81,8 +81,8 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - assert self.deep_gemm_expert is not None + if (self.allow_deep_gemm and N > 512 + and _valid_deep_gemm_shape(M, N, K)): return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 90a28192eccb..aedfad56d031 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -110,6 +110,8 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") + logger.info("MoE quant config %s", quant_config.__dict__) + self.experts = FusedMoE(num_experts=config.num_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, From faa9b2f7e3ebf6ab3556c1f2f46ee65724f3ba3e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 2 Jun 2025 16:16:38 +0000 Subject: [PATCH 19/77] wip Signed-off-by: Bill Nell --- .../device_communicators/cuda_communicator.py | 2 +- .../layers/fused_moe/cutlass_moe.py | 5 +++ .../layers/fused_moe/fused_batched_moe.py | 11 ++++-- .../layers/fused_moe/pplx_prepare_finalize.py | 8 ++-- .../compressed_tensors_moe.py | 38 ++++++++++++++++++- 5 files changed, 54 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 3958d566b174..4071802e5288 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -71,7 +71,7 @@ def __init__(self, device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend == "naive": + if all2all_backend == "naive" or len(all2all_backend) == 0: from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0ef4e4f767e3..f334a4d03cb1 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -204,6 +204,8 @@ def run_cutlass_moe_fp8( # TODO (bnell): split class batched vs. non-batched? # maybe remove need for passing aq to workspace_shapes +class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): + class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -360,6 +362,9 @@ def cutlass_moe_fp8( num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( 0) + if out_dtype is None: + out_dtype = a.dtype + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index cbfc126b9a5a..7d606dabdb9e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -62,13 +62,15 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + expert_id * stride_ase + offs_m * stride_asm + # XXXXXXX + a_scale_ptrs = a_scale_ptr + (expert_id * stride_ase) + (offs_m * stride_asm) offs_bsn = offs_n // group_n b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + offs_bsn * stride_bsn) # channel-wise elif per_channel_quant: + # TODO: probably not correct b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn[None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations @@ -77,7 +79,7 @@ def moe_mmk( # tensor-wise else: - a_scale = tl.load(a_scale_ptr) + a_scale = tl.load(a_scale_ptr) # + expert_id) #? b_scale = tl.load(b_scale_ptr + expert_id) # ----------------------------------------------------------- @@ -703,14 +705,15 @@ def batched_moe_kernel_quantize_input( # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) - assert A_scale is None or A_scale.dim() == 2 + assert A_scale is None or A_scale.ndim <= 2 A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, hidden_dim), A_scale, qtype, per_channel_quant, block_shape) A_q = A_q.view(E, -1, hidden_dim) if A_q_scale is not None: - if A_q_scale.ndim == 1: + if A_q_scale.numel() == 1: + A_q_scale = A_q_scale.view(1) A_q_scale = torch.repeat_interleave(A_q_scale, E, dim=0).view(E, 1, 1) else: A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 7726f2a92b51..73986b25c89c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -172,7 +172,7 @@ def prepare( (expert_x.size(2) + block_size - 1) // block_size, ) - print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}") + #print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}") expert_x_scale = torch.zeros( expert_x_scale_shape, @@ -180,7 +180,7 @@ def prepare( device=expert_x.device, ) - print(f"YYYYYYYYYYYYYYY {expert_x.shape}") + #print(f"YYYYYYYYYYYYYYY {expert_x.shape}") # This argument is optional, defaults to indices.size(0) # There's not much point setting this unless it is != indices.size(0) @@ -198,7 +198,7 @@ def prepare( if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] - print(f"ZZZZZZZZZZZZZZ") + #print(f"ZZZZZZZZZZZZZZ") if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, 0:1] @@ -227,7 +227,7 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - print("CCCCCCCCCCCCCCCCCCCC") + #print("CCCCCCCCCCCCCCCCCCCC") self.a2a.combine(out_tokens=output, indices=topk_ids, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index fa011266cf2f..a0d66c65ed34 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,6 +3,9 @@ import enum from enum import Enum + +import functools + from typing import Callable, Optional import torch @@ -12,6 +15,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.distributed import get_ep_group from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, @@ -398,6 +402,8 @@ def __init__( self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.use_pplx_kernels = False + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -572,6 +578,35 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: self.fused_experts_func = fused_experts + # XXXXXXXXXX + def select_gemm_impl(self, prepare_finalize): + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedPrepareAndFinalize, BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + assert not self.rocm_aiter_moe_enabled and not self.use_marlin + + assert isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)) + + logger.debug("BatchedTritonExperts(%s)", self.__classname__.__name__) + + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + self.use_pplx_kernels = True + return BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=all2all_manager.world_size, + dp_size=all2all_manager.tp_group.world_size, + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN + ), + ) + def apply( self, layer: torch.nn.Module, @@ -609,7 +644,8 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=torch.uint32 if self.use_pplx_kernels else None) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts_func( From 985ce2eeb96c6489376d651bc8c82bb034e80128 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 2 Jun 2025 21:32:45 +0000 Subject: [PATCH 20/77] triton + debug hacking Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 5 +++ .../layers/fused_moe/fused_batched_moe.py | 39 ++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 016955809313..b8070cb39fc7 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -298,6 +298,11 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) + # torch.testing.assert_close(triton_output, + # baseline_output, + # atol=2e-2, + # rtol=6e-2) + torch.testing.assert_close(triton_output, baseline_output, atol=2e-2, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7d606dabdb9e..9d12b7cd2c75 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -62,8 +62,8 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - # XXXXXXX - a_scale_ptrs = a_scale_ptr + (expert_id * stride_ase) + (offs_m * stride_asm) + # + (expert_id * stride_ase) ?? + a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm) offs_bsn = offs_n // group_n b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + offs_bsn * stride_bsn) @@ -74,12 +74,13 @@ def moe_mmk( b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn[None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations + # + (expert_id * stride_ase)?? a_scale_ptrs = a_scale_ptr + offs_m * stride_asm a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None] # tensor-wise else: - a_scale = tl.load(a_scale_ptr) # + expert_id) #? + a_scale = tl.load(a_scale_ptr) #+ expert_id * stride_ase ? b_scale = tl.load(b_scale_ptr + expert_id) # ----------------------------------------------------------- @@ -296,6 +297,14 @@ def batched_triton_kernel( c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + cta_n_start * stride_cn) + if use_fp8_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptr = a_scale_ptr + expert_id * stride_ase + cta_m_start * stride_asm + # channel-wise + elif per_channel_quant: + a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + expert_triton_kernel( a_ptr, b_ptr, @@ -388,12 +397,15 @@ def invoke_moe_batched_triton_kernel( C.stride(0), C.stride(1), C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim >= 2 else 0, - A_scale.stride(2) if A_scale is not None and A_scale.ndim == 3 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim >= 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + + A_scale.stride(0) if A_scale is not None and A_scale.ndim >= 2 else 0, #E + A_scale.stride(2) if A_scale is not None and A_scale.ndim == 3 else 0, #K + A_scale.stride(1) if A_scale is not None and A_scale.ndim >= 2 else 0, #M + + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, #E + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, #K + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, #N + # Blockwise quantization data 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], @@ -491,7 +503,7 @@ def prepare( num = self.max_num_tokens if self.per_act_token_quant else 1 scale_shape = (num_local_experts, num, 1) - print(f"SCALE_SHAPE {b_a1.shape} {scale_shape}") + #print(f"SCALE_SHAPE {self.block_shape} {b_a1.shape} {scale_shape}") b_a1_scale = torch.zeros( scale_shape, @@ -528,6 +540,8 @@ def prepare( if self.block_shape is None and not self.per_act_token_quant: b_a1_scale[idx] = b_s else: + #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") + assert rows == b_s.shape[0] and b_a1_scale.shape[-1] == b_s.shape[-1] b_a1_scale[idx, :rows] = b_s else: b_a1[idx, :rows, :] = rhs @@ -717,6 +731,11 @@ def batched_moe_kernel_quantize_input( A_q_scale = torch.repeat_interleave(A_q_scale, E, dim=0).view(E, 1, 1) else: A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) + + #print(f"A2Q_SCALE {A_q_scale.shape}") + #A_q_scale.fill_(0.0001) + #print(f"A_q_scale.stride = {A_q_scale.stride()}") + return A_q, A_q_scale if qtype is not None: From 4d114ee917e889e9c02e4828224ab4a4dac21e41 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 3 Jun 2025 22:02:05 +0000 Subject: [PATCH 21/77] batched mm tests with real scales + grouped quant Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 20 ++--- tests/kernels/moe/test_pplx_moe.py | 1 - tests/kernels/moe/utils.py | 1 + .../layers/fused_moe/fused_batched_moe.py | 33 +++++-- .../layers/fused_moe/fused_moe.py | 90 +++++++++++++++++++ .../layers/quantization/utils/fp8_utils.py | 2 +- 6 files changed, 128 insertions(+), 19 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index b8070cb39fc7..8b44c8a10d43 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -105,8 +105,8 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("block_shape", [None]) -@pytest.mark.parametrize("per_act_token_quant", [False]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype, block_shape: Optional[list[int]], @@ -202,10 +202,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, A_scale, B_scale, block_shape) - ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2, - tensors.num_expert_tokens, A_scale, B_scale, - block_shape[-2:]) - rtol, atol = { torch.float16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2), @@ -298,10 +294,14 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - # torch.testing.assert_close(triton_output, - # baseline_output, - # atol=2e-2, - # rtol=6e-2) + torch.testing.assert_close(triton_output, + baseline_output, + atol=2e-2, + rtol=2e-2) + + #print(f"TORCH {baseline_output.shape}\n{baseline_output}") + #print(f"TRITON {triton_output.shape}\n{triton_output}") + #print(f"BATCHED {batched_output.shape}\n{batched_output}") torch.testing.assert_close(triton_output, baseline_output, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f4d43ed73ac9..4e513c813c02 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -38,7 +38,6 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch from tests.kernels.moe.utils import ( - native_w8a8_block_matmul, torch_moe2, naive_batched_moe, ) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 5b1048797447..75915457896b 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -16,6 +16,7 @@ moe_kernel_quantize_input) from vllm.utils import round_up +from tests.kernels.quant_utils import native_w8a8_block_matmul def triton_moe( a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9d12b7cd2c75..f50ba180e7e4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -300,7 +300,7 @@ def batched_triton_kernel( if use_fp8_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptr = a_scale_ptr + expert_id * stride_ase + cta_m_start * stride_asm + a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + cta_m_start * stride_asm # channel-wise elif per_channel_quant: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) @@ -714,8 +714,9 @@ def batched_moe_kernel_quantize_input( per_channel_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (True or torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): + if (True or + torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) @@ -725,6 +726,11 @@ def batched_moe_kernel_quantize_input( qtype, per_channel_quant, block_shape) A_q = A_q.view(E, -1, hidden_dim) + + # for e in range(len(expert_num_tokens)): + # num = expert_num_tokens[e] + # A_q_scale[e, num:].fill_(0) + if A_q_scale is not None: if A_q_scale.numel() == 1: A_q_scale = A_q_scale.view(1) @@ -732,7 +738,7 @@ def batched_moe_kernel_quantize_input( else: A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) - #print(f"A2Q_SCALE {A_q_scale.shape}") + #print(f"A2Q_SCALE {A_q_scale.shape}\n{A_q_scale}") #A_q_scale.fill_(0.0001) #print(f"A_q_scale.stride = {A_q_scale.stride()}") @@ -744,7 +750,7 @@ def batched_moe_kernel_quantize_input( block_n, block_k = block_shape n_tiles = ((N // 2) + block_n - 1) // block_n scale_shape = (E, num_tokens, n_tiles) - A_q_scale = torch.empty(scale_shape, + A_q_scale = torch.zeros(scale_shape, dtype=torch.float32, device=A.device) for e in range(E): @@ -756,6 +762,15 @@ def batched_moe_kernel_quantize_input( per_channel_quant, block_shape) A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale + if A_q_scale is not None: + if A_q_scale.numel() == 1: + A_q_scale = A_q_scale.view(1) + A_q_scale = torch.repeat_interleave(A_q_scale, E, dim=0).view(E, 1, 1) + else: + A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) + + #print(f"A2Q_SCALE {A_q_scale.shape}\n{A_q_scale}") + return A_q, A_q_scale else: return A, A_scale @@ -936,8 +951,12 @@ def apply( # TODO: would be nice to use expert_num_tokens here to reduce # garbage compute - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + self.activation( + activation, + intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) + + #print(f"BATCHED ACT {intermediate_cache2.shape}\n{intermediate_cache2}") qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 75712b8e3a4d..25a332dc4f04 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -466,6 +466,74 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) + +def prepare_scales( + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + quant_dtype: Optional[torch.dtype], + block_shape: Optional[list[int]], + msg: str, +): + from vllm.utils import round_up + max_num_tokens = round_up(a1.shape[0], 64) + num_tokens, hidden_dim = a1.size() + topk = topk_ids.size(1) + + tokens_per_expert = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) + + num_local_experts = num_experts + + b_a1 = torch.zeros( + (num_local_experts, max_num_tokens, hidden_dim), + dtype=quant_dtype + if quant_dtype is not None else a1.dtype, + device=a1.device) + + if quant_dtype is not None: + if block_shape is not None: + _, block_k = block_shape + k_tiles = (hidden_dim + block_k - 1) // block_k + scale_shape = (num_local_experts, max_num_tokens, k_tiles) + else: + num = 1 + scale_shape = (num_local_experts, num, 1) + + b_a1_scale = torch.zeros( + scale_shape, + dtype=torch.float32, + device=a1.device) + else: + assert a1_scale is None + b_a1_scale = None + + first_expert = 0 + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + topks = torch.any(topk_ids == expert_id, dim=1).flatten() + rows = torch.count_nonzero(topks.flatten()) + rhs = a1[:topks.numel()][topks] + idx = expert_id - first_expert + b_a1[idx, :rows, :] = rhs + if quant_dtype is not None: + rhs_a1_scale = a1_scale[:topks.numel()][topks] + if block_shape is None: + b_a1_scale[idx] = rhs_a1_scale + else: + assert rows == rhs_a1_scale.shape[0] and b_a1_scale.shape[-1] == rhs_a1_scale.shape[-1] + b_a1_scale[idx, :rows] = rhs_a1_scale + + tokens_per_expert[idx] = rows + + print(f"{msg} {b_a1_scale.shape}\n{b_a1_scale}") + + return b_a1, b_a1_scale, tokens_per_expert + + def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @@ -1330,6 +1398,17 @@ def fused_experts_impl( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) + if False: + prepare_scales( + qcurr_hidden_states, + a1q_scale, + curr_topk_ids, + global_num_experts, + torch.float8_e4m3fn if use_fp8_w8a8 else None, + block_shape, + "First", + ) + invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, @@ -1367,6 +1446,17 @@ def fused_experts_impl( per_act_token_quant=per_channel_quant, block_shape=block_shape) + if False: + prepare_scales( + qintermediate_cache2, + a2q_scale, + curr_topk_ids, + global_num_experts, + torch.float8_e4m3fn if use_fp8_w8a8 else None, + block_shape, + "Second", + ) + invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index c38a445c571b..68b162331da5 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -593,7 +593,7 @@ def w8a8_block_fp8_matmul( assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], f"{block_k}, {triton.cdiv(A.shape[-1], block_k)} == As.shape[-1]" M = A.numel() // A.shape[-1] assert B.ndim == 2 and Bs.ndim == 2 From 8a019e2e9b052fe91f4c3630553dfd339fb5eff7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 4 Jun 2025 15:19:30 +0000 Subject: [PATCH 22/77] hacking on tests Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index f50ba180e7e4..2e5917cd2e2d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -63,10 +63,9 @@ def moe_mmk( # block-wise if group_k > 0 and group_n > 0: # + (expert_id * stride_ase) ?? - a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm) + a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm) #+ (expert_id * stride_ase) offs_bsn = offs_n // group_n - b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + - offs_bsn * stride_bsn) + b_scale_ptrs = (b_scale_ptr + offs_bsn * stride_bsn) + expert_id * stride_bse # channel-wise elif per_channel_quant: @@ -80,8 +79,8 @@ def moe_mmk( # tensor-wise else: - a_scale = tl.load(a_scale_ptr) #+ expert_id * stride_ase ? - b_scale = tl.load(b_scale_ptr + expert_id) + a_scale = tl.load(a_scale_ptr)# + (expert_id * stride_ase) + b_scale = tl.load(b_scale_ptr + expert_id * stride_bse) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -301,9 +300,11 @@ def batched_triton_kernel( # block-wise if group_k > 0 and group_n > 0: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + cta_m_start * stride_asm + #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) # + cta_n_start * stride_bsn? # channel-wise elif per_channel_quant: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) expert_triton_kernel( a_ptr, From dceee159b413de509df19a4ead751c7445434074 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 4 Jun 2025 21:34:41 +0000 Subject: [PATCH 23/77] scale hacking Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 1 + .../device_communicators/all2all.py | 3 - .../layers/fused_moe/fused_batched_moe.py | 67 ++++++++++++++----- .../layers/fused_moe/pplx_prepare_finalize.py | 50 +++++++++----- .../model_executor/layers/quantization/fp8.py | 1 + vllm/model_executor/models/granitemoe.py | 7 ++ vllm/model_executor/models/qwen3_moe.py | 2 +- 7 files changed, 93 insertions(+), 38 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 8b44c8a10d43..9eec7cbd0d14 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -184,6 +184,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 }, + per_act_token_quant=False, block_shape=block_shape, ) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 9e698819cbb6..85f87cb21edc 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -84,9 +84,6 @@ def __init__(self, cpu_group): ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa super().__init__(cpu_group) - # Intranode doesn't work yet. - self.internode = True - if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 2e5917cd2e2d..a2e7032d39b8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -30,8 +30,8 @@ def moe_mmk( stride_ak, stride_bk, stride_ase, - stride_ask, stride_asm, + stride_ask, stride_bse, stride_bsk, stride_bsn, @@ -157,8 +157,8 @@ def expert_triton_kernel( stride_cm, stride_cn, stride_ase, - stride_ask, stride_asm, + stride_ask, stride_bse, stride_bsk, stride_bsn, @@ -197,8 +197,8 @@ def expert_triton_kernel( stride_ak, stride_bk, stride_ase, - stride_ask, stride_asm, + stride_ask, stride_bse, stride_bsk, stride_bsn, @@ -254,8 +254,8 @@ def batched_triton_kernel( stride_cm, stride_cn, stride_ase, - stride_ask, stride_asm, + stride_ask, stride_bse, stride_bsk, stride_bsn, @@ -298,11 +298,11 @@ def batched_triton_kernel( if use_fp8_w8a8: # block-wise - if group_k > 0 and group_n > 0: + if (group_k > 0 and group_n > 0) or per_channel_quant: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + cta_m_start * stride_asm #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) # + cta_n_start * stride_bsn? - # channel-wise - elif per_channel_quant: + # channel-wise or tensor-wise + else: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) @@ -326,8 +326,8 @@ def batched_triton_kernel( stride_cm, stride_cn, stride_ase, - stride_ask, stride_asm, + stride_ask, stride_bse, stride_bsk, stride_bsn, @@ -374,6 +374,36 @@ def invoke_moe_batched_triton_kernel( grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) + assert A_scale is None or A_scale.ndim == 1 or A_scale.ndim == 3, f"{0 if A_scale is None else (A_scale.ndim, A_scale.shape)}" + assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else (B_scale.ndim, B_scale.shape)}" + + #print(f"SCALES {A_scale.shape}, {B_scale.shape}") + + stride_bse = 0 + stride_bsk = 0 + stride_bsn = 0 + if B_scale is not None: + if B_scale.ndim == 1: + stride_bsk = B_scale.stride(0) + else: + assert B_scale.ndim == 3 + stride_bse = B_scale.stride(0) + stride_bsn = B_scale.stride(1) + stride_bsk = B_scale.stride(2) + + stride_ase = 0 + stride_asm = 0 + stride_ask = 0 + if A_scale is not None: + if A_scale.ndim == 1: + stride_ask = A_scale.stride(0) + else: + assert A_scale.ndim == 3 + stride_ase = A_scale.stride(0) + stride_asm = A_scale.stride(1) + stride_ask = A_scale.stride(2) + + batched_triton_kernel[grid]( A, B, @@ -398,15 +428,12 @@ def invoke_moe_batched_triton_kernel( C.stride(0), C.stride(1), C.stride(2), - - A_scale.stride(0) if A_scale is not None and A_scale.ndim >= 2 else 0, #E - A_scale.stride(2) if A_scale is not None and A_scale.ndim == 3 else 0, #K - A_scale.stride(1) if A_scale is not None and A_scale.ndim >= 2 else 0, #M - - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, #E - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, #K - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, #N - + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, # Blockwise quantization data 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], @@ -549,7 +576,11 @@ def prepare( tokens_per_expert[idx] = rows - return b_a1, b_a1_scale, tokens_per_expert, None, None + #b_a1_scale.fill_(0.0001) + #print(f"A1Q_scale = {b_a1_scale.shape}\n{b_a1_scale}") + assert b_a1_scale is None or b_a1_scale.ndim == 3 + + return b_a1, b_a1_scale, tokens_per_expert def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 73986b25c89c..643f4ab00c68 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -121,24 +121,28 @@ def prepare( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) + # pplx requires 2-d scales even for scalars if a1q_scale is not None: + if a1q_scale.dim() <= 1: + assert a1q_scale.numel() == 1 + a1q_scale = a1q_scale.view(1, 1) + + #print(f"ORIG {a1q_scale.shape}, {a1q_scale}") + + orig_scale = a1q_scale + orig_a1q_scale_shape = a1q_scale.shape + + # pad out scales if needed if a1q_scale.numel() == 1: - orig_a_scale_block_shape = 1 - else: - orig_a_scale_block_shape = a1q_scale.shape[-1] - a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) + a1q_scale = a1q_scale.repeat(a1q.shape[1], 4) + + assert a1q_scale.shape[0] == a1q.shape[1] - # per_act_token_quant = a1_scale.numel() != 1 if a1_scale is not None else ( - # a2_scale.numel() != 1 if a2_scale is not None else False) + #print(f"FINAL {a1q_scale.shape}, {a1q_scale}") - # a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, - # self.quant_dtype, - # per_act_token, - # self.block_shape) - if a1q_scale is not None and a1q_scale.dim() == 1: - assert a1q_scale.numel() == 1 - a1q_scale = a1q_scale.view(1, 1) + assert a1q_scale is None or a1q_scale.ndim == 2, \ + f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" # rem_experts need to be 0 for pplx to work properly. rem_experts = num_experts % self.world_size @@ -169,7 +173,8 @@ def prepare( expert_x_scale_shape = ( num_local_experts, expert_x.size(1), - (expert_x.size(2) + block_size - 1) // block_size, + #(expert_x.size(2) + block_size - 1) // block_size, + orig_a1q_scale_shape[-1], ) #print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}") @@ -198,9 +203,22 @@ def prepare( if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] - #print(f"ZZZZZZZZZZZZZZ") + #print(f"ZZZZZZZZZZZZZZ {expert_x_scale.shape}") if expert_x_scale is not None: - expert_x_scale = expert_x_scale[:, :, 0:1] + expert_x_scale = expert_x_scale[:, :, :orig_a1q_scale_shape[-1]] + from math import prod + if prod(orig_a1q_scale_shape) == 1: + expert_x_scale = expert_x_scale[:, :1, :1] + #print(f"EPT {expert_num_tokens.flatten()}") + #print(f"SCALARIZING!!! {expert_x_scale.shape}, {expert_x_scale.flatten()}") + idx = expert_num_tokens.flatten() != 0 + assert torch.all(expert_x_scale.flatten()[idx] != 0) + #zidx = expert_num_tokens.flatten() == 0 + #assert torch.all(expert_x_scale.flatten()[zidx] == 0) + assert expert_x_scale.ndim == 3 + #expert_x_scale = orig_scale.view(1) + + assert expert_x_scale.ndim == 1 or expert_x_scale.ndim == 3 return expert_x, expert_x_scale, expert_num_tokens, None, None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c06b0fe9f36f..f662b755fcba 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -820,6 +820,7 @@ def select_gemm_impl( return TritonOrDeepGemmExperts( use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, + per_act_token=False, #? allow_deep_gemm=self.allow_deep_gemm, ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 5a70f3a616c6..61667749a536 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -92,6 +92,8 @@ def __init__(self, tp_size=tp_size, prefix=f"{prefix}.experts") + self.tp_size = tp_size if tp_size is not None else 1 + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape @@ -99,6 +101,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states) + return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index aedfad56d031..c98f2b77c673 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -137,7 +137,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - final_hidden_states = final_hidden_states + if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) From d6eda9b304683746b0f7d76d8a3c751d8437b5f6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 6 Jun 2025 01:48:16 +0000 Subject: [PATCH 24/77] wip hacking Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 2 + .../layers/fused_moe/pplx_prepare_finalize.py | 46 +++++-------------- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a2e7032d39b8..63ed38faf814 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -963,6 +963,8 @@ def apply( if self.use_fp8_w8a8: intermediate_cache1.fill_(0) + #print(f"A1_SCALES {a1q_scale.shape}") + # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, B=w1, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 643f4ab00c68..5cfea8831d56 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -121,25 +121,21 @@ def prepare( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) - # pplx requires 2-d scales even for scalars if a1q_scale is not None: + scalar_scales = a1q_scale.numel() == 1 + + # pplx requires 2-d scales even for scalar scales if a1q_scale.dim() <= 1: - assert a1q_scale.numel() == 1 + assert scalar_scales a1q_scale = a1q_scale.view(1, 1) - #print(f"ORIG {a1q_scale.shape}, {a1q_scale}") - - orig_scale = a1q_scale - orig_a1q_scale_shape = a1q_scale.shape - - # pad out scales if needed - if a1q_scale.numel() == 1: - a1q_scale = a1q_scale.repeat(a1q.shape[1], 4) - - assert a1q_scale.shape[0] == a1q.shape[1] + # pad out scales if needed. TODO (bnell): do for non-scalar scales? + if scalar_scales: + a1q_scale = a1q_scale.repeat(a1q.shape[1], torch.float32.itemsize) - #print(f"FINAL {a1q_scale.shape}, {a1q_scale}") + orig_a_scale_block_shape = a1q_scale.shape[-1] + #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}" assert a1q_scale is None or a1q_scale.ndim == 2, \ f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" @@ -173,20 +169,15 @@ def prepare( expert_x_scale_shape = ( num_local_experts, expert_x.size(1), - #(expert_x.size(2) + block_size - 1) // block_size, - orig_a1q_scale_shape[-1], + (expert_x.size(2) + block_size - 1) // block_size if not scalar_scales else 1, ) - #print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}") - expert_x_scale = torch.zeros( expert_x_scale_shape, dtype=torch.float32, device=expert_x.device, ) - #print(f"YYYYYYYYYYYYYYY {expert_x.shape}") - # This argument is optional, defaults to indices.size(0) # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None @@ -203,22 +194,9 @@ def prepare( if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] - #print(f"ZZZZZZZZZZZZZZ {expert_x_scale.shape}") if expert_x_scale is not None: - expert_x_scale = expert_x_scale[:, :, :orig_a1q_scale_shape[-1]] - from math import prod - if prod(orig_a1q_scale_shape) == 1: - expert_x_scale = expert_x_scale[:, :1, :1] - #print(f"EPT {expert_num_tokens.flatten()}") - #print(f"SCALARIZING!!! {expert_x_scale.shape}, {expert_x_scale.flatten()}") - idx = expert_num_tokens.flatten() != 0 - assert torch.all(expert_x_scale.flatten()[idx] != 0) - #zidx = expert_num_tokens.flatten() == 0 - #assert torch.all(expert_x_scale.flatten()[zidx] == 0) - assert expert_x_scale.ndim == 3 - #expert_x_scale = orig_scale.view(1) - - assert expert_x_scale.ndim == 1 or expert_x_scale.ndim == 3 + expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] + assert expert_x_scale.ndim == 3 return expert_x, expert_x_scale, expert_num_tokens, None, None From f554edabcfec95d065091f9312228bf6e4a04c87 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 10 Jun 2025 21:31:05 +0000 Subject: [PATCH 25/77] cleanup ctor args Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 3 - .../layers/fused_moe/__init__.py | 1 - .../batched_triton_or_deep_gemm_moe.py | 44 +++---- .../layers/fused_moe/cutlass_moe.py | 5 - .../layers/fused_moe/fused_batched_moe.py | 112 +++++++++--------- vllm/model_executor/layers/fused_moe/layer.py | 8 +- .../layers/fused_moe/modular_kernel.py | 12 +- .../layers/fused_moe/pplx_prepare_finalize.py | 2 +- .../compressed_tensors_moe.py | 22 ++-- 9 files changed, 99 insertions(+), 110 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 4e513c813c02..dc03db7a7743 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -430,9 +430,6 @@ def pplx_moe( world_size, rank, dp_size, - quant_dtype=qtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, ) experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 897e6700e7c4..3d40879b4ccb 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -38,7 +38,6 @@ def get_config() -> Optional[dict[str, Any]]: "FusedMoEPrepareAndFinalize", "override_config", "get_config", - "MOE_DP_CHUNK_SIZE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 3682a536cb5c..98b94f3bab9e 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -13,17 +13,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, - max_num_tokens: int, - world_size: int, - dp_size: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, - allow_deep_gemm: bool = False): + def __init__( + self, + max_num_tokens: int, + world_size: int, + dp_size: int, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, + allow_deep_gemm: bool = False + ): assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" @@ -42,8 +44,6 @@ def __init__(self, self.dp_size = dp_size self.allow_deep_gemm = allow_deep_gemm - # BatchedTritonKernel doesn't support block quantization - # at the moment. self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=self.max_num_tokens, world_size=self.world_size, @@ -54,18 +54,21 @@ def __init__(self, use_int4_w4a16=use_int4_w4a16, per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape, - ) if self.block_shape is None else None + ) - is_fp8_128_block_quantized = ( - use_fp8_w8a8 and self.block_shape - == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 + and self.block_shape + == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) self.batched_deep_gemm_experts = BatchedDeepGemmExperts( max_num_tokens=self.max_num_tokens, world_size=self.world_size, dp_size=self.dp_size, - block_shape=self.block_shape, # type: ignore[arg-type] - ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None + block_shape=self.block_shape, + ) if self.allow_deep_gemm else None + + assert (self.batched_triton_experts is not None or + (self.allow_deep_gemm and self.batched_deep_gemm_experts is not None)) assert (self.batched_deep_gemm_experts is not None or self.batched_triton_experts is not None) @@ -114,7 +117,6 @@ def workspace_shapes( return self.batched_deep_gemm_experts.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: - assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) @@ -142,7 +144,7 @@ def apply( and self.batched_deep_gemm_experts is not None) experts = (self.batched_deep_gemm_experts - if use_batched_deep_gemm_experts else + if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_ids, activation, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index f334a4d03cb1..0ef4e4f767e3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -204,8 +204,6 @@ def run_cutlass_moe_fp8( # TODO (bnell): split class batched vs. non-batched? # maybe remove need for passing aq to workspace_shapes -class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): - class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -362,9 +360,6 @@ def cutlass_moe_fp8( num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( 0) - if out_dtype is None: - out_dtype = a.dtype - fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 63ed38faf814..a5b3478c3c42 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -362,6 +362,8 @@ def invoke_moe_batched_triton_kernel( per_act_token_quant: bool, block_shape: Optional[list[int]] = None): + #print(f"TRITON MOE BATCHED {use_fp8_w8a8}, {per_act_token_quant}, {block_shape}") + assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) @@ -374,35 +376,26 @@ def invoke_moe_batched_triton_kernel( grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) - assert A_scale is None or A_scale.ndim == 1 or A_scale.ndim == 3, f"{0 if A_scale is None else (A_scale.ndim, A_scale.shape)}" - assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else (B_scale.ndim, B_scale.shape)}" - - #print(f"SCALES {A_scale.shape}, {B_scale.shape}") + assert A_scale is None or A_scale.ndim == 3, f"{0 if A_scale is None else A_scale.shape}" + assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}" - stride_bse = 0 - stride_bsk = 0 - stride_bsn = 0 if B_scale is not None: - if B_scale.ndim == 1: - stride_bsk = B_scale.stride(0) - else: - assert B_scale.ndim == 3 - stride_bse = B_scale.stride(0) - stride_bsn = B_scale.stride(1) - stride_bsk = B_scale.stride(2) - - stride_ase = 0 - stride_asm = 0 - stride_ask = 0 - if A_scale is not None: - if A_scale.ndim == 1: - stride_ask = A_scale.stride(0) - else: - assert A_scale.ndim == 3 - stride_ase = A_scale.stride(0) - stride_asm = A_scale.stride(1) - stride_ask = A_scale.stride(2) + stride_bse = B_scale.stride(0) + stride_bsn = B_scale.stride(1) + stride_bsk = B_scale.stride(2) + else: + stride_bse = 0 + stride_bsk = 0 + stride_bsn = 0 + if A_scale is not None: + stride_ase = A_scale.stride(0) + stride_asm = A_scale.stride(1) + stride_ask = A_scale.stride(2) + else: + stride_ase = 0 + stride_asm = 0 + stride_ask = 0 batched_triton_kernel[grid]( A, @@ -522,16 +515,16 @@ def prepare( dtype=b_type, device=a1.device) - if self.quant_dtype is not None: + if quant_dtype is not None: if self.block_shape is not None: - _, block_k = self.block_shape + _, block_k = block_shape k_tiles = (hidden_dim + block_k - 1) // block_k scale_shape = (num_local_experts, self.max_num_tokens, k_tiles) else: - num = self.max_num_tokens if self.per_act_token_quant else 1 + num = self.max_num_tokens if per_act_token_quant else 1 scale_shape = (num_local_experts, num, 1) - #print(f"SCALE_SHAPE {self.block_shape} {b_a1.shape} {scale_shape}") + #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}") b_a1_scale = torch.zeros( scale_shape, @@ -551,7 +544,7 @@ def prepare( continue rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert - if self.quant_dtype is not None: + if quant_dtype is not None: if a1_scale is not None: assert False, "NYI" rhs_a1_scale = a1_scale[:topks.numel()][topks] @@ -561,11 +554,11 @@ def prepare( moe_kernel_quantize_input( rhs, rhs_a1_scale, - self.quant_dtype, - self.per_act_token_quant, - self.block_shape, + quant_dtype, + per_act_token_quant, + block_shape, )) - if self.block_shape is None and not self.per_act_token_quant: + if block_shape is None and not per_act_token_quant: b_a1_scale[idx] = b_s else: #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") @@ -580,7 +573,7 @@ def prepare( #print(f"A1Q_scale = {b_a1_scale.shape}\n{b_a1_scale}") assert b_a1_scale is None or b_a1_scale.ndim == 3 - return b_a1, b_a1_scale, tokens_per_expert + return b_a1, b_a1_scale, tokens_per_expert, None, None def finalize( self, @@ -653,7 +646,6 @@ def __init__( per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) - assert block_m is None self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size @@ -743,7 +735,7 @@ def batched_moe_kernel_quantize_input( N: int, expert_num_tokens: torch.Tensor, qtype: Optional[torch.dtype], - per_channel_quant: bool, + per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if (True or @@ -753,10 +745,13 @@ def batched_moe_kernel_quantize_input( # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) assert A_scale is None or A_scale.ndim <= 2 - A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, - hidden_dim), A_scale, - qtype, per_channel_quant, - block_shape) + A_q, A_q_scale = moe_kernel_quantize_input( + A.view(-1, hidden_dim), + A_scale, + qtype, + per_act_token_quant, + block_shape + ) A_q = A_q.view(E, -1, hidden_dim) # for e in range(len(expert_num_tokens)): @@ -779,30 +774,33 @@ def batched_moe_kernel_quantize_input( if qtype is not None: assert block_shape is not None A_q = torch.empty_like(A, dtype=qtype) - block_n, block_k = block_shape - n_tiles = ((N // 2) + block_n - 1) // block_n - scale_shape = (E, num_tokens, n_tiles) + + if per_act_token_quant: + assert block_shape is None + scale_shape = (E, num_tokens, 1) + elif block_shape is not None: + block_n, block_k = block_shape + n_tiles = (A.shape[-1] + block_n - 1) // block_n + scale_shape = (E, num_tokens, n_tiles) + else: + scale_shape = (E, 1, 1) + A_q_scale = torch.zeros(scale_shape, dtype=torch.float32, device=A.device) + for e in range(E): num_tokens = expert_num_tokens[e] if num_tokens > 0: A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( A[e, :num_tokens], - A_scale[e, :num_tokens] if A_scale else None, qtype, - per_channel_quant, block_shape) + A_scale[e, :min(num_tokens,A_scale.shape[1])] if A_scale is not None else None, + qtype, + per_act_token_quant, + block_shape + ) A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale - if A_q_scale is not None: - if A_q_scale.numel() == 1: - A_q_scale = A_q_scale.view(1) - A_q_scale = torch.repeat_interleave(A_q_scale, E, dim=0).view(E, 1, 1) - else: - A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) - - #print(f"A2Q_SCALE {A_q_scale.shape}\n{A_q_scale}") - return A_q, A_q_scale else: return A, A_scale @@ -993,7 +991,7 @@ def apply( #print(f"BATCHED ACT {intermediate_cache2.shape}\n{intermediate_cache2}") qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens, + intermediate_cache2, a2_scale, max_num_tokens, E, N, expert_num_tokens, self.quant_dtype, self.per_act_token_quant, self.block_shape) invoke_moe_batched_triton_kernel(A=qintermediate_cache2, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 37948b83741f..cf4daf6c6d91 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -36,6 +36,11 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from .modular_kernel import (FusedMoEModularKernel, + FusedMoEPrepareAndFinalize, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) + if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts @@ -92,6 +97,8 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None if moe.use_pplx_kernels: + block_shape = quant_config.weight_block_size if quant_config is not None else None + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, moe.hidden_dim, @@ -396,7 +403,6 @@ def forward_cuda( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c1b8b2317c34..a3f2f055dc39 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -83,6 +83,7 @@ def _moe_problem_size( return E, M, N, K, topk +# TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEActivationFormat(Enum): """ @@ -102,16 +103,6 @@ class FusedMoEPrepareAndFinalize(ABC): described above. """ - def __init__( - self, - quant_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - block_shape: Optional[list[int]], - ): - self.quant_dtype = quant_dtype - self.per_act_token_quant = per_act_token_quant - self.block_shape = block_shape - @abstractmethod def prepare( self, @@ -216,6 +207,7 @@ def __init__( per_act_token_quant: bool, block_shape: Optional[list[int]], ): + assert not per_act_token_quant or block_shape is None self.quant_dtype = quant_dtype self.per_act_token_quant = per_act_token_quant self.block_shape = block_shape diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5cfea8831d56..6e84e16bcafe 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -153,7 +153,7 @@ def prepare( ) num_dp = self.world_size // self.dp_size - expert_x = torch.empty( + expert_x = torch.zeros( (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), dtype=a1q.dtype, device=device, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a0d66c65ed34..fac55f3c7441 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -402,8 +402,6 @@ def __init__( self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() - self.use_pplx_kernels = False - def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -581,23 +579,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # XXXXXXXXXX def select_gemm_impl(self, prepare_finalize): from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts) - from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + BatchedTritonExperts) assert not self.rocm_aiter_moe_enabled and not self.use_marlin - assert isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)) - logger.debug("BatchedTritonExperts(%s)", self.__classname__.__name__) all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - self.use_pplx_kernels = True + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + use_batched_experts = max_num_tokens_per_rank is not None + + assert use_batched_experts + return BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, + max_num_tokens=max_num_tokens_per_rank, world_size=all2all_manager.world_size, dp_size=all2all_manager.tp_group.world_size, use_fp8_w8a8=True, @@ -645,7 +642,8 @@ def apply( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=torch.uint32 if self.use_pplx_kernels else None) + indices_type=self.topk_indices_dtype, + ) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts_func( @@ -869,6 +867,8 @@ def select_gemm_impl( num_experts = (moe.num_local_experts if use_batched_format else moe.num_experts) + logger.debug("CutlassExpertsFp8(%s)", self.__classname__.__name__) + experts = CutlassExpertsFp8( num_experts, moe.in_dtype, From 8e70e606580ee4e7f5cc6ad39be399b883e3645d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 11 Jun 2025 02:30:33 +0000 Subject: [PATCH 26/77] wip Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 7 ++++--- .../layers/fused_moe/fused_batched_moe.py | 17 +++++++++++------ vllm/model_executor/layers/fused_moe/layer.py | 3 +-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 9eec7cbd0d14..3649d25eb721 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -72,7 +72,7 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): - if config.in_dtype == torch.torch.float8_e4m3fn: + if config.in_dtype == torch.float8_e4m3fn: config_in_dtype = torch.bfloat16 else: config_in_dtype = config.in_dtype @@ -103,8 +103,9 @@ def make_tensors(config: BatchedMMConfig): [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024]) -@pytest.mark.parametrize("dtype", - [torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "dtype", + [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a5b3478c3c42..ca7d08ba4614 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -62,7 +62,6 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - # + (expert_id * stride_ase) ?? a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm) #+ (expert_id * stride_ase) offs_bsn = offs_n // group_n b_scale_ptrs = (b_scale_ptr + offs_bsn * stride_bsn) + expert_id * stride_bse @@ -377,12 +376,18 @@ def invoke_moe_batched_triton_kernel( triton.cdiv(B.size(1), BLOCK_N)) assert A_scale is None or A_scale.ndim == 3, f"{0 if A_scale is None else A_scale.shape}" - assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}" + assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}" + #assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else (A.shape, B_scale.shape)}" if B_scale is not None: - stride_bse = B_scale.stride(0) - stride_bsn = B_scale.stride(1) - stride_bsk = B_scale.stride(2) + if B_scale.ndim == 1: + stride_bse = 1 + stride_bsn = 0 + stride_bsk = 0 + else: + stride_bse = B_scale.stride(0) + stride_bsn = B_scale.stride(1) + stride_bsk = B_scale.stride(2) else: stride_bse = 0 stride_bsk = 0 @@ -516,7 +521,7 @@ def prepare( device=a1.device) if quant_dtype is not None: - if self.block_shape is not None: + if block_shape is not None: _, block_k = block_shape k_tiles = (hidden_dim + block_k - 1) // block_k scale_shape = (num_local_experts, self.max_num_tokens, k_tiles) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cf4daf6c6d91..2e2fdc3b7ae2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -97,8 +97,6 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None if moe.use_pplx_kernels: - block_shape = quant_config.weight_block_size if quant_config is not None else None - hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, moe.hidden_dim, @@ -403,6 +401,7 @@ def forward_cuda( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, From 5f5e9a389e452b82212c3ea874f407653ab89410 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 11 Jun 2025 14:52:44 +0000 Subject: [PATCH 27/77] fixes Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 12 +++++++----- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 3 +-- .../layers/fused_moe/prepare_finalize.py | 7 +++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index dc03db7a7743..56ceda5d518e 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -432,11 +432,13 @@ def pplx_moe( dp_size, ) - experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=qtype == torch.float8_e4m3fn, - block_shape=block_shape) + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=qtype==torch.float8_e4m3fn, + block_shape=block_shape + ) fused_experts = FusedMoEModularKernel( prepare_finalize, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0ef4e4f767e3..ba421a21c4ce 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -357,8 +357,7 @@ def cutlass_moe_fp8( """ per_out_ch = w1_scale.numel() != w1_q.size(0) - num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( - 0) + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 9e4be82f6c1f..00927283dc8a 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -45,9 +45,16 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) +<<<<<<< HEAD a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) +======= + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + quant_dtype, + per_act_token_quant, + block_shape) +>>>>>>> 60564d0a4 (fixes) return a1q, a1q_scale, None, None, None From 9d30bcca6a96b0ef90d90ad709f212785caf78f2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 12 Jun 2025 21:43:38 +0000 Subject: [PATCH 28/77] refactoring Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 1 + .../layers/fused_moe/batched_deep_gemm_moe.py | 2 + .../batched_triton_or_deep_gemm_moe.py | 2 + .../layers/fused_moe/fused_batched_moe.py | 39 ++++++++++--------- .../layers/fused_moe/fused_moe.py | 2 + .../layers/fused_moe/modular_kernel.py | 28 +++++++++---- .../layers/fused_moe/pplx_prepare_finalize.py | 18 ++++++--- .../layers/fused_moe/prepare_finalize.py | 12 ++---- .../layers/fused_moe/triton_deep_gemm_moe.py | 4 +- 9 files changed, 67 insertions(+), 41 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 56ceda5d518e..34811f38a838 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -40,6 +40,7 @@ from tests.kernels.moe.utils import ( torch_moe2, naive_batched_moe, + make_test_weights, ) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 6b08f32dff18..b11c3855481e 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,6 +7,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig) from vllm.triton_utils import tl, triton logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 98b94f3bab9e..949b67e415ff 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -9,6 +9,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index ca7d08ba4614..a66aaa96cbb5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -10,9 +10,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, get_config_quant_dtype, try_get_optimal_moe_config) + get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig) @triton.jit @@ -520,13 +522,13 @@ def prepare( dtype=b_type, device=a1.device) - if quant_dtype is not None: - if block_shape is not None: - _, block_k = block_shape + if quant_config.quant_dtype is not None: + if quant_config.block_shape is not None: + _, block_k = quant_config.block_shape k_tiles = (hidden_dim + block_k - 1) // block_k scale_shape = (num_local_experts, self.max_num_tokens, k_tiles) else: - num = self.max_num_tokens if per_act_token_quant else 1 + num = self.max_num_tokens if quant_config.per_act_token_quant else 1 scale_shape = (num_local_experts, num, 1) #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}") @@ -549,7 +551,7 @@ def prepare( continue rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert - if quant_dtype is not None: + if quant_config.quant_dtype is not None: if a1_scale is not None: assert False, "NYI" rhs_a1_scale = a1_scale[:topks.numel()][topks] @@ -559,11 +561,11 @@ def prepare( moe_kernel_quantize_input( rhs, rhs_a1_scale, - quant_dtype, - per_act_token_quant, - block_shape, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, )) - if block_shape is None and not per_act_token_quant: + if quant_config.block_shape is None and not quant_config.per_act_token_quant: b_a1_scale[idx] = b_s else: #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") @@ -640,16 +642,15 @@ def __init__( assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" - quant_dtype = get_config_quant_dtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - ) super().__init__( - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) ) self.max_num_tokens = max_num_tokens self.world_size = world_size diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 25a332dc4f04..c06619d8539f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,6 +9,8 @@ import torch import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig) import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a3f2f055dc39..794623f00da2 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -203,14 +203,28 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, - quant_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - block_shape: Optional[list[int]], + quant_config: Optional[FusedMoEQuantConfig], ): - assert not per_act_token_quant or block_shape is None - self.quant_dtype = quant_dtype - self.per_act_token_quant = per_act_token_quant - self.block_shape = block_shape + if quant_config is not None: + self.quant_config = quant_config + else: + self.quant_config = FusedMoEQuantConfig() + + @property + def quant_dtype(self): + return self.quant_config.quant_dtype + + @property + def block_shape(self): + return self.quant_config.block_shape + + @property + def per_act_token_quant(self): + return self.quant_config.per_act_token_quant + + @property + def per_out_ch_quant(self): + return self.quant_config.per_out_ch_quant def __init__( self, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 6e84e16bcafe..b1613603cd72 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -129,11 +129,12 @@ def prepare( assert scalar_scales a1q_scale = a1q_scale.view(1, 1) + orig_a_scale_block_shape = a1q_scale.shape[-1] + # pad out scales if needed. TODO (bnell): do for non-scalar scales? if scalar_scales: - a1q_scale = a1q_scale.repeat(a1q.shape[1], torch.float32.itemsize) - - orig_a_scale_block_shape = a1q_scale.shape[-1] + #print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") + a1q_scale = a1q_scale.repeat(a1q.shape[1], 4*torch.float32.itemsize) #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}" @@ -153,6 +154,7 @@ def prepare( ) num_dp = self.world_size // self.dp_size + #print(f"EXPERT_X {(num_local_experts, self.max_num_tokens * num_dp, hidden_dim)}, {a1q.dtype}, {device}") expert_x = torch.zeros( (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), dtype=a1q.dtype, @@ -172,6 +174,8 @@ def prepare( (expert_x.size(2) + block_size - 1) // block_size if not scalar_scales else 1, ) + #print(f"EXPERT_X_SCALE {expert_x_scale_shape}") + expert_x_scale = torch.zeros( expert_x_scale_shape, dtype=torch.float32, @@ -182,6 +186,8 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None + #print(f"DISPATCH X={expert_x.shape}, X_SCALE={expert_x_scale.shape}, A={a1q.shape}, A_SCALE={a1q_scale.shape}, TOPK={topk_ids}") + self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -191,6 +197,8 @@ def prepare( indices=topk_ids, bound_m=bound_m, ) + #print(f"DISPATCH DONE {device}") + if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] @@ -223,10 +231,10 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - #print("CCCCCCCCCCCCCCCCCCCC") - + #print(f"COMBINE {output.device}") self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) + #print(f"COMBINE DONE {output.device}") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 00927283dc8a..e222d2abc828 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -45,16 +45,10 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) -<<<<<<< HEAD - a1q, a1q_scale = moe_kernel_quantize_input( - a1, a1_scale, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) -======= a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, - quant_dtype, - per_act_token_quant, - block_shape) ->>>>>>> 60564d0a4 (fixes) + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape) return a1q, a1q_scale, None, None, None diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 4a1c7d4be1ba..cbd32aed1423 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -9,7 +9,9 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, get_config_quant_dtype) + TritonExperts) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig) class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): From 9fd583359a7c264ef6e10c999bbb4db85a0e91ca Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 12 Jun 2025 21:46:05 +0000 Subject: [PATCH 29/77] lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 2 - tests/kernels/moe/test_pplx_moe.py | 23 +- .../batched_triton_or_deep_gemm_moe.py | 34 +-- .../layers/fused_moe/cutlass_moe.py | 3 +- .../layers/fused_moe/fused_batched_moe.py | 259 +++++++++--------- .../layers/fused_moe/fused_moe.py | 16 +- vllm/model_executor/layers/fused_moe/layer.py | 1 - .../layers/fused_moe/modular_kernel.py | 1 + .../layers/fused_moe/pplx_prepare_finalize.py | 6 +- .../layers/fused_moe/prepare_finalize.py | 7 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 7 +- .../compressed_tensors_moe.py | 8 +- .../model_executor/layers/quantization/fp8.py | 4 +- .../layers/quantization/utils/fp8_utils.py | 3 +- 14 files changed, 176 insertions(+), 198 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 3649d25eb721..b106974733cb 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -16,8 +16,6 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( invoke_moe_batched_triton_kernel) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 34811f38a838..55aab8e2c88b 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,28 +28,17 @@ from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) from vllm.platforms import current_platform from vllm.utils import round_up from .parallel_utils import ProcessGroupInfo, parallel_launch -from tests.kernels.moe.utils import ( - torch_moe2, - naive_batched_moe, - make_test_weights, -) - requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", ) - PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), (222, 2048, 1024)] @@ -433,13 +422,11 @@ def pplx_moe( dp_size, ) - experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=qtype==torch.float8_e4m3fn, - block_shape=block_shape - ) + experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=qtype == torch.float8_e4m3fn, + block_shape=block_shape) fused_experts = FusedMoEModularKernel( prepare_finalize, diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 949b67e415ff..13bab1bf37dd 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -9,25 +9,21 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - max_num_tokens: int, - world_size: int, - dp_size: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, - allow_deep_gemm: bool = False - ): + def __init__(self, + max_num_tokens: int, + world_size: int, + dp_size: int, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, + allow_deep_gemm: bool = False): assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" @@ -69,8 +65,9 @@ def __init__( block_shape=self.block_shape, ) if self.allow_deep_gemm else None - assert (self.batched_triton_experts is not None or - (self.allow_deep_gemm and self.batched_deep_gemm_experts is not None)) + assert (self.batched_triton_experts is not None + or (self.allow_deep_gemm + and self.batched_deep_gemm_experts is not None)) assert (self.batched_deep_gemm_experts is not None or self.batched_triton_experts is not None) @@ -146,8 +143,7 @@ def apply( and self.batched_deep_gemm_experts is not None) experts = (self.batched_deep_gemm_experts - if self.allow_deep_gemm else - self.batched_triton_experts) + if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_ids, activation, global_num_experts, expert_map, w1_scale, w2_scale, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index ba421a21c4ce..0ef4e4f767e3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -357,7 +357,8 @@ def cutlass_moe_fp8( """ per_out_ch = w1_scale.numel() != w1_q.size(0) - num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( + 0) fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a66aaa96cbb5..6845893720a6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -13,45 +13,43 @@ get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig) @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_ase, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr, - per_channel_quant: tl.constexpr, + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, ): offs_k = tl.arange(0, BLOCK_K) @@ -64,23 +62,26 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm) #+ (expert_id * stride_ase) + a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm + ) #+ (expert_id * stride_ase) offs_bsn = offs_n // group_n - b_scale_ptrs = (b_scale_ptr + offs_bsn * stride_bsn) + expert_id * stride_bse + b_scale_ptrs = (b_scale_ptr + + offs_bsn * stride_bsn) + expert_id * stride_bse # channel-wise elif per_channel_quant: # TODO: probably not correct - b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn[None, :] * stride_bsn + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn[ + None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations # + (expert_id * stride_ase)?? a_scale_ptrs = a_scale_ptr + offs_m * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None] + a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] # tensor-wise else: - a_scale = tl.load(a_scale_ptr)# + (expert_id * stride_ase) + a_scale = tl.load(a_scale_ptr) # + (expert_id * stride_ase) b_scale = tl.load(b_scale_ptr + expert_id * stride_bse) # ----------------------------------------------------------- @@ -137,43 +138,43 @@ def moe_mmk( @triton.jit def expert_triton_kernel( - a_ptr, #[max_tokens, K] - b_ptr, #[K, N] - c_ptr, #[max_tokens, N] - expert_id, - compute_type: tl.constexpr, - # Dimensions - M, - N, - K, - # Quantization data - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # strides - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_ase, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n, - group_k, - # Quantization schemes - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - per_channel_quant: tl.constexpr, - # Kernel config - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, + a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ): offs_m = tl.arange(0, BLOCK_M) @@ -300,7 +301,8 @@ def batched_triton_kernel( if use_fp8_w8a8: # block-wise if (group_k > 0 and group_n > 0) or per_channel_quant: - a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + cta_m_start * stride_asm + a_scale_ptr = a_scale_ptr + (expert_id * + stride_ase) + cta_m_start * stride_asm #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) # + cta_n_start * stride_bsn? # channel-wise or tensor-wise else: @@ -533,10 +535,9 @@ def prepare( #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}") - b_a1_scale = torch.zeros( - scale_shape, - dtype=torch.float32, - device=a1.device) + b_a1_scale = torch.zeros(scale_shape, + dtype=torch.float32, + device=a1.device) else: assert a1_scale is None b_a1_scale = None @@ -557,19 +558,19 @@ def prepare( rhs_a1_scale = a1_scale[:topks.numel()][topks] else: rhs_a1_scale = None - b_a1[idx, :rows, :], b_s = ( - moe_kernel_quantize_input( - rhs, - rhs_a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - )) + b_a1[idx, :rows, :], b_s = (moe_kernel_quantize_input( + rhs, + rhs_a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + )) if quant_config.block_shape is None and not quant_config.per_act_token_quant: b_a1_scale[idx] = b_s else: #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") - assert rows == b_s.shape[0] and b_a1_scale.shape[-1] == b_s.shape[-1] + assert rows == b_s.shape[0] and b_a1_scale.shape[ + -1] == b_s.shape[-1] b_a1_scale[idx, :rows] = b_s else: b_a1[idx, :rows, :] = rhs @@ -650,8 +651,7 @@ def __init__( use_int4_w4a16=use_int4_w4a16, per_act_token_quant=per_act_token_quant, block_shape=block_shape, - ) - ) + )) self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size @@ -744,20 +744,16 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (True or - torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): + if (True or torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) assert A_scale is None or A_scale.ndim <= 2 - A_q, A_q_scale = moe_kernel_quantize_input( - A.view(-1, hidden_dim), - A_scale, - qtype, - per_act_token_quant, - block_shape - ) + A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, + hidden_dim), A_scale, + qtype, per_act_token_quant, + block_shape) A_q = A_q.view(E, -1, hidden_dim) # for e in range(len(expert_num_tokens)): @@ -767,7 +763,8 @@ def batched_moe_kernel_quantize_input( if A_q_scale is not None: if A_q_scale.numel() == 1: A_q_scale = A_q_scale.view(1) - A_q_scale = torch.repeat_interleave(A_q_scale, E, dim=0).view(E, 1, 1) + A_q_scale = torch.repeat_interleave(A_q_scale, E, + dim=0).view(E, 1, 1) else: A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) @@ -800,11 +797,9 @@ def batched_moe_kernel_quantize_input( if num_tokens > 0: A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( A[e, :num_tokens], - A_scale[e, :min(num_tokens,A_scale.shape[1])] if A_scale is not None else None, - qtype, - per_act_token_quant, - block_shape - ) + A_scale[e, :min(num_tokens, A_scale.shape[1])] + if A_scale is not None else None, qtype, + per_act_token_quant, block_shape) A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale @@ -970,29 +965,39 @@ def apply( #print(f"A1_SCALES {a1q_scale.shape}") # MM1 - invoke_moe_batched_triton_kernel(A=hidden_states, - B=w1, - C=intermediate_cache1, - expert_num_tokens=expert_num_tokens, - compute_type=compute_type, - A_scale=a1q_scale, - B_scale=w1_scale, - B_zp=w1_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - config=config, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) + invoke_moe_batched_triton_kernel( + A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape) intermediate_cache2.fill_(0) # TODO: would be nice to use expert_num_tokens here to reduce # garbage compute - self.activation( - activation, - intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + if False: + # TODO: check expert_num_tokens + tmp = torch.empty_like(intermediate_cache2[0]) + for e in range(E): + num_tokens = expert_num_tokens[e] + self.activation(activation, tmp[:num_tokens], + intermediate_cache1[e, :num_tokens]) + intermediate_cache2[e, :num_tokens] = tmp[:num_tokens] + else: + self.activation( + activation, + intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) #print(f"BATCHED ACT {intermediate_cache2.shape}\n{intermediate_cache2}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c06619d8539f..d03f9ac4a992 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,8 +9,6 @@ import torch import vllm.envs as envs -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig) import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -468,7 +466,6 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) - def prepare_scales( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -491,8 +488,7 @@ def prepare_scales( b_a1 = torch.zeros( (num_local_experts, max_num_tokens, hidden_dim), - dtype=quant_dtype - if quant_dtype is not None else a1.dtype, + dtype=quant_dtype if quant_dtype is not None else a1.dtype, device=a1.device) if quant_dtype is not None: @@ -504,10 +500,9 @@ def prepare_scales( num = 1 scale_shape = (num_local_experts, num, 1) - b_a1_scale = torch.zeros( - scale_shape, - dtype=torch.float32, - device=a1.device) + b_a1_scale = torch.zeros(scale_shape, + dtype=torch.float32, + device=a1.device) else: assert a1_scale is None b_a1_scale = None @@ -526,7 +521,8 @@ def prepare_scales( if block_shape is None: b_a1_scale[idx] = rhs_a1_scale else: - assert rows == rhs_a1_scale.shape[0] and b_a1_scale.shape[-1] == rhs_a1_scale.shape[-1] + assert rows == rhs_a1_scale.shape[0] and b_a1_scale.shape[ + -1] == rhs_a1_scale.shape[-1] b_a1_scale[idx, :rows] = rhs_a1_scale tokens_per_expert[idx] = rows diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2e2fdc3b7ae2..959c7417f87f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -37,7 +37,6 @@ from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from .modular_kernel import (FusedMoEModularKernel, - FusedMoEPrepareAndFinalize, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 794623f00da2..a6945933a5be 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -83,6 +83,7 @@ def _moe_problem_size( return E, M, N, K, topk + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEActivationFormat(Enum): diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b1613603cd72..c520d1fbc043 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -134,7 +134,8 @@ def prepare( # pad out scales if needed. TODO (bnell): do for non-scalar scales? if scalar_scales: #print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") - a1q_scale = a1q_scale.repeat(a1q.shape[1], 4*torch.float32.itemsize) + a1q_scale = a1q_scale.repeat(a1q.shape[1], + 4 * torch.float32.itemsize) #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}" @@ -171,7 +172,8 @@ def prepare( expert_x_scale_shape = ( num_local_experts, expert_x.size(1), - (expert_x.size(2) + block_size - 1) // block_size if not scalar_scales else 1, + (expert_x.size(2) + block_size - 1) // + block_size if not scalar_scales else 1, ) #print(f"EXPERT_X_SCALE {expert_x_scale_shape}") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index e222d2abc828..9e4be82f6c1f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -45,10 +45,9 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape) + a1q, a1q_scale = moe_kernel_quantize_input( + a1, a1_scale, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index cbd32aed1423..4a55f6cffac2 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -8,10 +8,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -84,7 +81,7 @@ def workspace_shapes( # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if (self.allow_deep_gemm and N > 512 - and _valid_deep_gemm_shape(M, N, K)): + and _valid_deep_gemm_shape(M, N, K)): return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index fac55f3c7441..f0f2f8a572c1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,9 +3,6 @@ import enum from enum import Enum - -import functools - from typing import Callable, Optional import torch @@ -600,8 +597,7 @@ def select_gemm_impl(self, prepare_finalize): use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN - ), + self.input_quant.strategy == QuantizationStrategy.TOKEN), ) def apply( @@ -871,7 +867,7 @@ def select_gemm_impl( experts = CutlassExpertsFp8( num_experts, - moe.in_dtype, + None, #moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, use_batched_format=use_batched_format, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f662b755fcba..0c01a188dbdc 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,7 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat, @@ -820,7 +820,7 @@ def select_gemm_impl( return TritonOrDeepGemmExperts( use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, - per_act_token=False, #? + per_act_token=False, #? allow_deep_gemm=self.allow_deep_gemm, ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 68b162331da5..1f85db685650 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -593,7 +593,8 @@ def w8a8_block_fp8_matmul( assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], f"{block_k}, {triton.cdiv(A.shape[-1], block_k)} == As.shape[-1]" + assert triton.cdiv(A.shape[-1], block_k) == As.shape[ + -1], f"{block_k}, {triton.cdiv(A.shape[-1], block_k)} == As.shape[-1]" M = A.numel() // A.shape[-1] assert B.ndim == 2 and Bs.ndim == 2 From 16a4d7f4a1beab676db7f2bb6cd6ad48c1a555e3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 12 Jun 2025 22:13:54 +0000 Subject: [PATCH 30/77] lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d03f9ac4a992..64d24f86627b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -478,7 +478,7 @@ def prepare_scales( from vllm.utils import round_up max_num_tokens = round_up(a1.shape[0], 64) num_tokens, hidden_dim = a1.size() - topk = topk_ids.size(1) + #topk = topk_ids.size(1) tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 959c7417f87f..df8014b5656b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -31,15 +31,15 @@ is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx -from .modular_kernel import (FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) - if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts From c21e4df65e67bbf7bc3a75bb3dced2aef7b0b7bd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 12 Jun 2025 22:24:19 +0000 Subject: [PATCH 31/77] lint Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 18 ++++++++++++------ vllm/model_executor/layers/fused_moe/layer.py | 4 ---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 6845893720a6..9011e0879245 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -303,7 +303,8 @@ def batched_triton_kernel( if (group_k > 0 and group_n > 0) or per_channel_quant: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + cta_m_start * stride_asm - #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) # + cta_n_start * stride_bsn? + #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + # (?) b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn # channel-wise or tensor-wise else: a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) @@ -379,9 +380,10 @@ def invoke_moe_batched_triton_kernel( grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) - assert A_scale is None or A_scale.ndim == 3, f"{0 if A_scale is None else A_scale.shape}" - assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}" - #assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else (A.shape, B_scale.shape)}" + assert A_scale is None or A_scale.ndim == 3, ( + f"{0 if A_scale is None else A_scale.shape}") + assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, ( + f"{0 if B_scale is None else B_scale.shape}") if B_scale is not None: if B_scale.ndim == 1: @@ -530,7 +532,10 @@ def prepare( k_tiles = (hidden_dim + block_k - 1) // block_k scale_shape = (num_local_experts, self.max_num_tokens, k_tiles) else: - num = self.max_num_tokens if quant_config.per_act_token_quant else 1 + if quant_config.per_act_token_quant: + num = self.max_num_tokens + else: + num = 1 scale_shape = (num_local_experts, num, 1) #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}") @@ -565,7 +570,8 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, )) - if quant_config.block_shape is None and not quant_config.per_act_token_quant: + if (quant_config.block_shape is None and + not quant_config.per_act_token_quant): b_a1_scale[idx] = b_s else: #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index df8014b5656b..37948b83741f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -31,10 +31,6 @@ is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum From 82a0b1e060f132feb3a4534b3676f2f02d724804 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 12 Jun 2025 22:27:51 +0000 Subject: [PATCH 32/77] lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9011e0879245..dc059c079a16 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -570,8 +570,8 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, )) - if (quant_config.block_shape is None and - not quant_config.per_act_token_quant): + if (quant_config.block_shape is None + and not quant_config.per_act_token_quant): b_a1_scale[idx] = b_s else: #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") From 39c9b5eb4b6aa9e3da7eb348b258a91c0cd97afa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 13 Jun 2025 03:10:11 +0000 Subject: [PATCH 33/77] fix merge. split up int8/fp8 moe tests Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 8 ++++---- .../layers/fused_moe/triton_deep_gemm_moe.py | 3 +-- vllm/model_executor/layers/quantization/fp8.py | 3 --- vllm/model_executor/models/qwen3_moe.py | 2 -- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a6945933a5be..c51933069f50 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -212,19 +212,19 @@ def __init__( self.quant_config = FusedMoEQuantConfig() @property - def quant_dtype(self): + def quant_dtype(self) -> Optional[torch.dtype]: return self.quant_config.quant_dtype @property - def block_shape(self): + def block_shape(self) -> Optional[list[int]]: return self.quant_config.block_shape @property - def per_act_token_quant(self): + def per_act_token_quant(self) -> bool: return self.quant_config.per_act_token_quant @property - def per_out_ch_quant(self): + def per_out_ch_quant(self) -> bool: return self.quant_config.per_out_ch_quant def __init__( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 4a55f6cffac2..c88788ac1bfe 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -80,8 +80,7 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if (self.allow_deep_gemm and N > 512 - and _valid_deep_gemm_shape(M, N, K)): + if (self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K)): return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0c01a188dbdc..28caa1ceaffc 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -481,9 +481,6 @@ def __init__(self, quant_config: Fp8Config): block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) - self.use_pplx_kernels = False - self.rocm_aiter_moe_enabled = False - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index c98f2b77c673..ff182aadf738 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -110,8 +110,6 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") - logger.info("MoE quant config %s", quant_config.__dict__) - self.experts = FusedMoE(num_experts=config.num_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, From c03676366941830b4e974a520287abab267f73b3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 18 Jun 2025 21:54:25 +0000 Subject: [PATCH 34/77] wip Signed-off-by: Bill Nell --- .../layers/fused_moe/batched_triton_or_deep_gemm_moe.py | 6 +----- .../model_executor/layers/fused_moe/fused_batched_moe.py | 9 --------- .../layers/fused_moe/triton_deep_gemm_moe.py | 2 +- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 13bab1bf37dd..530d30023ca8 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -62,13 +62,9 @@ def __init__(self, max_num_tokens=self.max_num_tokens, world_size=self.world_size, dp_size=self.dp_size, - block_shape=self.block_shape, + block_shape=self.block_shape, # type: ignore[arg-type] ) if self.allow_deep_gemm else None - assert (self.batched_triton_experts is not None - or (self.allow_deep_gemm - and self.batched_deep_gemm_experts is not None)) - assert (self.batched_deep_gemm_experts is not None or self.batched_triton_experts is not None) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index dc059c079a16..2b1c97919562 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -649,15 +649,6 @@ def __init__( assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index c88788ac1bfe..a12dee4885eb 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -80,7 +80,7 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if (self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K)): + if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: From 861500e66a59a1bfd95bb0ab1baf8f1d69de59a0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 20 Jun 2025 19:02:48 +0000 Subject: [PATCH 35/77] cleanup Signed-off-by: Bill Nell --- tests/kernels/moe/utils.py | 1 - .../device_communicators/cuda_communicator.py | 2 +- .../batched_triton_or_deep_gemm_moe.py | 4 +-- vllm/model_executor/layers/fused_moe/layer.py | 1 + .../layers/fused_moe/modular_kernel.py | 27 ------------------- .../layers/fused_moe/pplx_prepare_finalize.py | 4 +-- .../layers/fused_moe/triton_deep_gemm_moe.py | 1 + .../compressed_tensors_moe.py | 15 +++++------ .../model_executor/layers/quantization/fp8.py | 1 - .../layers/quantization/utils/fp8_utils.py | 3 +-- 10 files changed, 12 insertions(+), 47 deletions(-) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 75915457896b..5b1048797447 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -16,7 +16,6 @@ moe_kernel_quantize_input) from vllm.utils import round_up -from tests.kernels.quant_utils import native_w8a8_block_matmul def triton_moe( a: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4071802e5288..3958d566b174 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -71,7 +71,7 @@ def __init__(self, device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend == "naive" or len(all2all_backend) == 0: + if all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 530d30023ca8..65bd4f49b57f 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -112,6 +112,7 @@ def workspace_shapes( return self.batched_deep_gemm_experts.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: + assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) @@ -135,9 +136,6 @@ def apply( workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], ): - use_batched_deep_gemm_experts = (self.allow_deep_gemm - and self.batched_deep_gemm_experts - is not None) experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 37948b83741f..8501abd9e609 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -732,6 +732,7 @@ def __init__( num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, quant_config=quant_config, ) self.moe_config = moe diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c51933069f50..2ffb4d328eca 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -84,8 +84,6 @@ def _moe_problem_size( return E, M, N, K, topk -# TODO: pass FusedMoEParallelConfig in as ctor parameter? - class FusedMoEActivationFormat(Enum): """ The standard activation format (num_tokens, hidden dim). @@ -202,31 +200,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC): above. """ - def __init__( - self, - quant_config: Optional[FusedMoEQuantConfig], - ): - if quant_config is not None: - self.quant_config = quant_config - else: - self.quant_config = FusedMoEQuantConfig() - - @property - def quant_dtype(self) -> Optional[torch.dtype]: - return self.quant_config.quant_dtype - - @property - def block_shape(self) -> Optional[list[int]]: - return self.quant_config.block_shape - - @property - def per_act_token_quant(self) -> bool: - return self.quant_config.per_act_token_quant - - @property - def per_out_ch_quant(self) -> bool: - return self.quant_config.per_out_ch_quant - def __init__( self, quant_config: Optional[FusedMoEQuantConfig], diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index c520d1fbc043..fdb124b18ea0 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -112,7 +112,6 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) - repeat_cols = 4 repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) a1q, a1q_scale = moe_kernel_quantize_input( @@ -156,7 +155,7 @@ def prepare( num_dp = self.world_size // self.dp_size #print(f"EXPERT_X {(num_local_experts, self.max_num_tokens * num_dp, hidden_dim)}, {a1q.dtype}, {device}") - expert_x = torch.zeros( + expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), dtype=a1q.dtype, device=device, @@ -168,7 +167,6 @@ def prepare( block_size = (quant_config.block_shape[1] if quant_config. block_shape is not None else 1) * float32_size - expert_x_scale_shape = ( num_local_experts, expert_x.size(1), diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index a12dee4885eb..e660376ebe6b 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -81,6 +81,7 @@ def workspace_shapes( # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f0f2f8a572c1..3f0a1c2634d6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -12,7 +12,6 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed import get_ep_group from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, @@ -573,7 +572,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: self.fused_experts_func = fused_experts - # XXXXXXXXXX def select_gemm_impl(self, prepare_finalize): from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) @@ -582,18 +580,17 @@ def select_gemm_impl(self, prepare_finalize): logger.debug("BatchedTritonExperts(%s)", self.__classname__.__name__) - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None + use_batched_format = (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts) - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - use_batched_experts = max_num_tokens_per_rank is not None + assert use_batched_format - assert use_batched_experts + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() return BatchedTritonExperts( max_num_tokens=max_num_tokens_per_rank, - world_size=all2all_manager.world_size, - dp_size=all2all_manager.tp_group.world_size, + world_size=moe.world_size, + dp_size=moe.dp_size, use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, per_act_token_quant=( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 28caa1ceaffc..0295f5e2a1c8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -817,7 +817,6 @@ def select_gemm_impl( return TritonOrDeepGemmExperts( use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, - per_act_token=False, #? allow_deep_gemm=self.allow_deep_gemm, ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 1f85db685650..c38a445c571b 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -593,8 +593,7 @@ def w8a8_block_fp8_matmul( assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[ - -1], f"{block_k}, {triton.cdiv(A.shape[-1], block_k)} == As.shape[-1]" + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] M = A.numel() // A.shape[-1] assert B.ndim == 2 and Bs.ndim == 2 From ae39492d5b6c400b3ad352d3fcdd27241db11601 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 20 Jun 2025 19:22:49 +0000 Subject: [PATCH 36/77] cleanup Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index b106974733cb..d4ce66a62419 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -70,11 +70,6 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): - if config.in_dtype == torch.float8_e4m3fn: - config_in_dtype = torch.bfloat16 - else: - config_in_dtype = config.in_dtype - A = torch.randn( (config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", @@ -183,7 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 }, - per_act_token_quant=False, + per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) @@ -294,11 +289,6 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - torch.testing.assert_close(triton_output, - baseline_output, - atol=2e-2, - rtol=2e-2) - #print(f"TORCH {baseline_output.shape}\n{baseline_output}") #print(f"TRITON {triton_output.shape}\n{triton_output}") #print(f"BATCHED {batched_output.shape}\n{batched_output}") From ae45963c744e7b70dc813fbfa000ed971816195c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 21 Jun 2025 02:01:57 +0000 Subject: [PATCH 37/77] fixes after merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 30 +++++++++++++++---- .../fused_moe/deepep_ll_prepare_finalize.py | 10 ++----- .../layers/fused_moe/fused_batched_moe.py | 23 +++++++++----- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 55aab8e2c88b..96d64b53dfee 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform from vllm.utils import round_up @@ -198,6 +199,7 @@ def pplx_prepare_finalize( pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, + a_scale: Optional[torch.Tensor], topk_weight: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -289,6 +291,7 @@ def _pplx_prepare_finalize( pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, + a_scale: Optional[torch.Tensor], score: torch.Tensor, topk: torch.Tensor, num_experts: int, @@ -316,7 +319,7 @@ def _pplx_prepare_finalize( topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to( a.dtype) - pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, + pplx_output = pplx_prepare_finalize(pgi, dp_size, a, a_scale, topk_weight, topk_ids, num_experts, group_name) torch_output = chunk_by_rank(torch_output, pgi.rank, @@ -335,8 +338,10 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("per_act_token_quant", [False]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_prepare_finalize( @@ -345,16 +350,31 @@ def test_pplx_prepare_finalize( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]], use_internode: bool, ): + if dtype == torch.float8_e4m3fn: + use_fp8_w8a8 = True + act_dtype = torch.bfloat16 + else: + use_fp8_w8a8 = False + act_dtype = dtype + + if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None: + pytest.skip("Skip quantization test for non-quantized type") + current_platform.seed_everything(7) m, n, k = mnk world_size, dp_size = world_dp_size device = "cuda" - a = torch.randn((m, k), device=device, dtype=dtype) / 10 - score = torch.randn((m, e), device=device, dtype=dtype) - parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, + a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 + score = torch.randn((m, e), device=device, dtype=act_dtype) + + a, a_scale = moe_kernel_quantize_input(a, None, dtype, False, block_shape) + + parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, a_scale, score, topk, e, use_internode) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b315b4a97f04..4ef2948156ab 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -27,7 +27,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, expert_x_fp32 = expert_x_fp8.to(torch.float32).view( num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) - return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -39,12 +39,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # specific hidden sizes. SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168] - def __init__(self, - buffer: deep_ep.Buffer, - max_tokens_per_rank: int, - world_size: int, - dp_size: int, - use_fp8_dispatch: bool = False): + def __init__(self, buffer: deep_ep.Buffer, max_tokens_per_rank: int, + world_size: int, dp_size: int, use_fp8_dispatch: bool): super().__init__() self.buffer = buffer diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 2b1c97919562..c58be699e115 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -729,6 +729,19 @@ def apply( self.activation(activation, tmp, input) output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) +def maybe_fix_scales(scales: Optional[torch.Tensor], num_experts: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + scales = scales.view(1) + scales = torch.repeat_interleave( + scales, + num_experts, + dim=0 + ).view(num_experts, 1, 1) + else: + scales = scales.view(num_experts, -1, scales.size(-1)) + + return scales def batched_moe_kernel_quantize_input( A: torch.Tensor, @@ -757,13 +770,7 @@ def batched_moe_kernel_quantize_input( # num = expert_num_tokens[e] # A_q_scale[e, num:].fill_(0) - if A_q_scale is not None: - if A_q_scale.numel() == 1: - A_q_scale = A_q_scale.view(1) - A_q_scale = torch.repeat_interleave(A_q_scale, E, - dim=0).view(E, 1, 1) - else: - A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1)) + A_q_scale = maybe_fix_scales(A_q_scale, E) #print(f"A2Q_SCALE {A_q_scale.shape}\n{A_q_scale}") #A_q_scale.fill_(0.0001) @@ -960,6 +967,8 @@ def apply( intermediate_cache1.fill_(0) #print(f"A1_SCALES {a1q_scale.shape}") + a1q_scale = maybe_fix_scales(a1q_scale, E) + a2_scale = maybe_fix_scales(a2_scale, E) # MM1 invoke_moe_batched_triton_kernel( From b783ce6bd78e092020c384d7a013562af1a63859 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 21 Jun 2025 23:00:16 +0000 Subject: [PATCH 38/77] torch_experts working Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 52 +++++++++----- tests/kernels/moe/test_pplx_moe.py | 10 ++- .../model_executor/layers/fused_moe/config.py | 45 ++++++++++++ .../layers/fused_moe/fused_batched_moe.py | 72 ++++++++++--------- .../layers/fused_moe/pplx_prepare_finalize.py | 11 +-- 5 files changed, 133 insertions(+), 57 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index d4ce66a62419..eb94065d3aad 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -9,6 +9,7 @@ import triton.language as tl from tests.kernels.moe.utils import (batched_moe, + naive_batched_moe, make_quantized_test_activations, make_test_weights, triton_moe) from tests.kernels.quant_utils import native_batched_masked_quant_matmul @@ -135,7 +136,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, in_dtype=act_dtype, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant) + per_act_token_quant=per_act_token_quant, + ) B, B_q, B_scale, _, _, _ = make_test_weights( num_experts, @@ -144,6 +146,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, in_dtype=act_dtype, quant_dtype=quant_dtype, block_shape=block_shape, + per_act_token_quant=per_act_token_quant, ) out_shape = (num_experts, max_tokens_per_expert, N) @@ -203,16 +206,18 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.float32: (1e-2, 1e-2), }[test_output.dtype] - torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) - torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) + + #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) + #torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) -@pytest.mark.parametrize("per_act_token_quant", [False]) -@pytest.mark.parametrize("block_shape", [None]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) def test_fused_moe_batched_experts( m: int, n: int, @@ -227,10 +232,13 @@ def test_fused_moe_batched_experts( use_fp8_w8a8 = dtype == torch.float8_e4m3fn + if topk > e: + pytest.skip("topk > e") + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): pytest.skip("Skip quantization test for non-quantized type") - if per_act_token_quant and block_shape is not None or topk > e: + if per_act_token_quant and block_shape is not None: pytest.skip("Skip illegal quantization test.") a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 @@ -243,12 +251,15 @@ def test_fused_moe_batched_experts( act_dtype = dtype quant_dtype = None - _, w1, w1_s, _, w2, w2_s = make_test_weights(e, - n, - k, - block_shape=block_shape, - in_dtype=act_dtype, - quant_dtype=quant_dtype) + w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights( + e, + n, + k, + block_shape=block_shape, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + ) with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) @@ -293,12 +304,17 @@ def test_fused_moe_batched_experts( #print(f"TRITON {triton_output.shape}\n{triton_output}") #print(f"BATCHED {batched_output.shape}\n{batched_output}") - torch.testing.assert_close(triton_output, + torch.testing.assert_close(batched_output, baseline_output, - atol=2e-2, + atol=3e-2, rtol=2e-2) - torch.testing.assert_close(triton_output, - batched_output, - atol=2e-2, - rtol=2e-2) + # torch.testing.assert_close(triton_output, + # baseline_output, + # atol=2e-2, + # rtol=2e-2) + + # torch.testing.assert_close(triton_output, + # batched_output, + # atol=2e-2, + # rtol=2e-2) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 96d64b53dfee..6ee3d58a5a03 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -361,9 +361,12 @@ def test_pplx_prepare_finalize( use_fp8_w8a8 = False act_dtype = dtype - if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None: + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): pytest.skip("Skip quantization test for non-quantized type") + if per_act_token_quant and block_shape is not None: + pytest.skip("Skip illgal quantization combination") + current_platform.seed_everything(7) m, n, k = mnk world_size, dp_size = world_dp_size @@ -654,9 +657,12 @@ def test_pplx_moe( use_fp8_w8a8 = False quant_dtype = None - if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None: + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): pytest.skip("Skip quantization test for non-quantized type") + if per_act_token_quant and block_shape is not None: + pytest.skip("Skip illgal quantization combination") + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 9a678406b8f3..47c88f4619f4 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -68,6 +68,51 @@ class FusedMoEQuantConfig: # TODO: add col major flag? # add detailed quant info for input, intermediates, weights, etc? + def __post_init__(self): + assert (not self.per_act_token_quant or + self.block_shape is None), "illegal quantization" + + @property + def is_quantized(self) -> bool: + return self.quant_dtype is not None + + @property + def is_per_act_token(self) -> bool: + return self.per_act_token_quant + + @property + def is_grouped(self) -> bool: + return self.block_shape is not None + + @property + def is_per_tensor(self) -> bool: + return not self.per_act_token_quant and self.block_shape is None + + def scale_shape(self, max_tokens: int, hidden_dim: int) -> Optional[tuple[int, int]]: + if self.is_quantized: + if self.is_grouped: + _, block_k = self.block_shape + k_tiles = cdiv(hidden_dim, block_k) + return (max_tokens, k_tiles) + elif self.is_per_act_token: + return (max_tokens, 1) + else: + return (1, 1) + else: + return None + + def batched_scale_shape( + self, + num_experts: int, + max_tokens: int, + hidden_dim: int + ) -> Optional[tuple[int, int, int]]: + if self.is_quantized: + scale_shape = self.scale_shape(max_tokens, hidden_dim) + return (num_experts, *scale_shape) + else: + return None + @staticmethod def make( use_fp8_w8a8: bool = False, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index c58be699e115..3a48ce357109 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -71,7 +71,7 @@ def moe_mmk( # channel-wise elif per_channel_quant: # TODO: probably not correct - b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn[ + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations @@ -526,19 +526,10 @@ def prepare( dtype=b_type, device=a1.device) - if quant_config.quant_dtype is not None: - if quant_config.block_shape is not None: - _, block_k = quant_config.block_shape - k_tiles = (hidden_dim + block_k - 1) // block_k - scale_shape = (num_local_experts, self.max_num_tokens, k_tiles) - else: - if quant_config.per_act_token_quant: - num = self.max_num_tokens - else: - num = 1 - scale_shape = (num_local_experts, num, 1) - - #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}") + if quant_config.is_quantized: + scale_shape = quant_config.batched_scale_shape(num_local_experts, + self.max_num_tokens, + hidden_dim) b_a1_scale = torch.zeros(scale_shape, dtype=torch.float32, @@ -555,36 +546,25 @@ def prepare( rows = torch.count_nonzero(topks.flatten()) if rows == 0: continue - rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert + tokens_per_expert[idx] = rows + rhs = a1[:topks.numel()][topks] if quant_config.quant_dtype is not None: if a1_scale is not None: assert False, "NYI" rhs_a1_scale = a1_scale[:topks.numel()][topks] else: rhs_a1_scale = None - b_a1[idx, :rows, :], b_s = (moe_kernel_quantize_input( + b_a1[idx, :rows, :], b_a1_scale[idx] = (moe_kernel_quantize_input( rhs, rhs_a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, )) - if (quant_config.block_shape is None - and not quant_config.per_act_token_quant): - b_a1_scale[idx] = b_s - else: - #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}") - assert rows == b_s.shape[0] and b_a1_scale.shape[ - -1] == b_s.shape[-1] - b_a1_scale[idx, :rows] = b_s else: b_a1[idx, :rows, :] = rhs - tokens_per_expert[idx] = rows - - #b_a1_scale.fill_(0.0001) - #print(f"A1Q_scale = {b_a1_scale.shape}\n{b_a1_scale}") assert b_a1_scale is None or b_a1_scale.ndim == 3 return b_a1, b_a1_scale, tokens_per_expert, None, None @@ -645,7 +625,6 @@ def __init__( per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) - assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" @@ -684,6 +663,15 @@ def workspace_shapes( workspace2 = (self.max_num_tokens * num_dp, N) return (workspace13, workspace2, workspace13, a.dtype) + def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + assert self.quant_config.is_quantized + f32 = torch.float32 + if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: + return t.to(f32) * scale + else: + t32 = t.to(f32).view(-1, self.quant_config.block_shape[1]) + return (t32 * scale.view(-1, 1)).view(t.shape) + def apply( self, output: torch.Tensor, @@ -711,7 +699,12 @@ def apply( assert num_local_experts == w1.size(0), ( f"{num_local_experts} == {w1.size(0)}") + assert a2_scale is None, "NYI" + N = w1.size(1) // 2 + f32 = torch.float32 + + #output.fill_(0) for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor @@ -725,9 +718,23 @@ def apply( continue tmp = _resize_cache(workspace2, (num, N)) - input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) - self.activation(activation, tmp, input) - output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + + if self.quant_config.is_quantized: + input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) + w1_dq = self.dequant(w1[expert], w1_scale[expert]) + input = input[:num] @ w1_dq.transpose(0, 1) + else: + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + + self.activation(activation, tmp, input.to(tmp.dtype)) + + if self.quant_config.is_quantized: + w2_dq = self.dequant(w2[expert], w2_scale[expert]) + else: + w2_dq = w2[expert] + + output[expert, :num, :] = tmp @ w2_dq.transpose(0, 1).to(tmp.dtype) + def maybe_fix_scales(scales: Optional[torch.Tensor], num_experts: int) -> Optional[torch.Tensor]: if scales is not None: @@ -743,6 +750,7 @@ def maybe_fix_scales(scales: Optional[torch.Tensor], num_experts: int) -> Option return scales + def batched_moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index fdb124b18ea0..93abf769dca5 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -131,11 +131,13 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] # pad out scales if needed. TODO (bnell): do for non-scalar scales? - if scalar_scales: - #print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") + if False and scalar_scales: + print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") a1q_scale = a1q_scale.repeat(a1q.shape[1], 4 * torch.float32.itemsize) + a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) + #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}" assert a1q_scale is None or a1q_scale.ndim == 2, \ @@ -170,11 +172,10 @@ def prepare( expert_x_scale_shape = ( num_local_experts, expert_x.size(1), - (expert_x.size(2) + block_size - 1) // - block_size if not scalar_scales else 1, + cdiv(expert_x.size(2), block_size) if not scalar_scales else 1, ) - #print(f"EXPERT_X_SCALE {expert_x_scale_shape}") + print(f"EXPERT_X_SCALE {expert_x_scale_shape}") expert_x_scale = torch.zeros( expert_x_scale_shape, From 3ea14546c4e29ed4bf2410096b05a3129966c9d4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 22 Jun 2025 00:12:53 +0000 Subject: [PATCH 39/77] fp8 baselines working Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 17 ++++++++++------- vllm/model_executor/layers/fused_moe/config.py | 2 ++ .../layers/fused_moe/fused_batched_moe.py | 10 +++++++--- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 6ee3d58a5a03..3b93ed465f29 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -171,9 +171,9 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) + baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) # only for baseline torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) - batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this torch.testing.assert_close(baseline_output, torch_output, @@ -666,11 +666,14 @@ def test_pplx_moe( a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - _, w1, w1_s, _, w2, w2_s = make_test_weights(e, - n, - k, - quant_dtype=quant_dtype, - block_shape=block_shape) + _, w1, w1_s, _, w2, w2_s = make_test_weights( + e, + n, + k, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + ) parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 47c88f4619f4..9e8c5afc4de3 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -14,6 +14,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.utils import cdiv + logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 3a48ce357109..77d58582f93c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -13,6 +13,7 @@ get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast @triton.jit @@ -555,13 +556,17 @@ def prepare( rhs_a1_scale = a1_scale[:topks.numel()][topks] else: rhs_a1_scale = None - b_a1[idx, :rows, :], b_a1_scale[idx] = (moe_kernel_quantize_input( + b_a1[idx, :rows, :], b_s = (moe_kernel_quantize_input( rhs, rhs_a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, )) + if quant_config.is_per_tensor: + b_a1_scale[idx] = b_s + else: + b_a1_scale[idx, :rows] = b_s[:rows] else: b_a1[idx, :rows, :] = rhs @@ -669,8 +674,7 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: return t.to(f32) * scale else: - t32 = t.to(f32).view(-1, self.quant_config.block_shape[1]) - return (t32 * scale.view(-1, 1)).view(t.shape) + return t.to(f32) * group_broadcast(scale, t.shape) def apply( self, From 9d8dd1dcccced9dd7a65373bc40982b685f98e4d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 22 Jun 2025 03:11:26 +0000 Subject: [PATCH 40/77] mm baselines work Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index eb94065d3aad..2b5653f3f4da 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -190,15 +190,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, B, ref_output, num_expert_tokens, - None, - None, - None, ) q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output, num_expert_tokens, A_scale, B_scale, - block_shape) + block_shape, + per_act_token_quant) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -207,7 +205,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, }[test_output.dtype] torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) - #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) #torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) From 5b376e5bcee00f439bccccb1b363169a2aeb4307 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 23 Jun 2025 21:59:02 +0000 Subject: [PATCH 41/77] prepare_finalize wokring Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 2 +- tests/kernels/moe/test_pplx_moe.py | 96 ++++++++++++++----- .../layers/fused_moe/pplx_prepare_finalize.py | 45 ++++++--- 3 files changed, 104 insertions(+), 39 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 2b5653f3f4da..0810b5c4cb04 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -206,7 +206,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) - #torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 3b93ed465f29..eea4293cc6c7 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -20,6 +20,7 @@ from tests.kernels.moe.utils import make_test_weights, naive_batched_moe from tests.kernels.utils import torch_experts +from tests.kernels.quant_utils import dequant from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig @@ -40,8 +41,14 @@ reason="Requires PPLX kernels", ) -PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), - (222, 2048, 1024)] +PPLX_PREPARE_COMBOS = [ +# (1, 128, 128), + (4, 128, 128), + (32, 1024, 512), +# (45, 512, 2048), + (64, 1024, 512), + (222, 2048, 1024), +] PPLX_MOE_COMBOS = [ (1, 128, 128), @@ -195,18 +202,24 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: return t[(r * chunk):(r + 1) * chunk] +def dummy_work(a: torch.Tensor) -> torch.Tensor: + return a # * 1.5 + + def pplx_prepare_finalize( pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, - a_scale: Optional[torch.Tensor], topk_weight: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, + quant_dtype: Optional[torch.dtype], + block_shape: Optional[list[int]], + per_act_token_quant: bool, group_name: Optional[str], ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) assert torch.cuda.current_device() == pgi.local_rank @@ -215,7 +228,16 @@ def pplx_prepare_finalize( device = pgi.device rank = pgi.rank world_size = pgi.world_size - max_num_tokens = rank_chunk(num_tokens, 0, world_size) + max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1) + + hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( + max_num_tokens, + hidden_dim, + a.dtype, + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) args = dict( max_num_tokens=max_num_tokens, @@ -225,8 +247,8 @@ def pplx_prepare_finalize( world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=0, + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=scale_bytes, ) if group_name is None: @@ -258,10 +280,17 @@ def pplx_prepare_finalize( num_experts, None, False, - FusedMoEQuantConfig(), + FusedMoEQuantConfig( + quant_dtype, + per_act_token_quant, + False, + block_shape, + ), ) - b_a = b_a * 1.5 + # Do some fake work + #print(f"INTER {b_a.shape} {b_a_scale.shape if b_a_scale is not None else None}") + b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) out = torch.full( (max_num_tokens, hidden_dim), @@ -291,10 +320,12 @@ def _pplx_prepare_finalize( pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, - a_scale: Optional[torch.Tensor], score: torch.Tensor, topk: torch.Tensor, num_experts: int, + quant_dtype: Optional[torch.dtype], + block_shape: Optional[list[int]], + per_act_token_quant: bool, use_internode: bool, ): if use_internode: @@ -308,24 +339,35 @@ def _pplx_prepare_finalize( cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name - device = pgi.device + #device = pgi.device topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - k = a.shape[1] + m, k = a.shape - a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) + a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) #.to(device) - torch_output = (a_rep.view(-1, topk, k) * 1.5 * - topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to( - a.dtype) + if True: + torch_output = (a_rep.view(m, topk, k) * + topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1) + else: + import vllm._custom_ops as ops + a_rep = a_rep.view(m, topk, k) + a_rep.mul_(topk_weight.view(m, topk, 1).to(a_rep.dtype)) + torch_output = torch.empty_like(a) + ops.moe_sum(a_rep, torch_output) - pplx_output = pplx_prepare_finalize(pgi, dp_size, a, a_scale, topk_weight, topk_ids, - num_experts, group_name) + pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, + num_experts, quant_dtype, block_shape, + per_act_token_quant, group_name) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + #torch.set_printoptions(profile="full") + #print(f"PPLX {pplx_output.shape}\n{pplx_output.shape}") + #print(f"TORCH {torch_output.shape}\n{torch_output.shape}") + + torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) if use_internode: nvshmem_finalize() @@ -334,7 +376,6 @@ def _pplx_prepare_finalize( # TODO (bnell): this test point does not work for odd M due to how the test is # written, not due to limitations of the pplx kernels. The pplx_moe # test below is able to deal with odd M. -# TODO (bnell) add fp8 tests @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @@ -357,28 +398,31 @@ def test_pplx_prepare_finalize( if dtype == torch.float8_e4m3fn: use_fp8_w8a8 = True act_dtype = torch.bfloat16 + quant_dtype = dtype else: use_fp8_w8a8 = False act_dtype = dtype + quant_dtype = None if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): pytest.skip("Skip quantization test for non-quantized type") if per_act_token_quant and block_shape is not None: - pytest.skip("Skip illgal quantization combination") + pytest.skip("Skip illegal quantization combination") current_platform.seed_everything(7) m, n, k = mnk world_size, dp_size = world_dp_size device = "cuda" + #print(f"MNK = {mnk}") + a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 score = torch.randn((m, e), device=device, dtype=act_dtype) - a, a_scale = moe_kernel_quantize_input(a, None, dtype, False, block_shape) - - parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, a_scale, score, - topk, e, use_internode) + parallel_launch(world_size, _pplx_prepare_finalize, dp_size, + a, score, topk, e, quant_dtype, block_shape, + per_act_token_quant, use_internode) def pplx_moe( @@ -661,7 +705,7 @@ def test_pplx_moe( pytest.skip("Skip quantization test for non-quantized type") if per_act_token_quant and block_shape is not None: - pytest.skip("Skip illgal quantization combination") + pytest.skip("Skip illegal quantization combination") a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 93abf769dca5..bfce386d8538 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -120,6 +120,17 @@ def prepare( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) + if quant_config.quant_dtype is not None: + if quant_config.is_per_tensor: + assert a1q_scale.numel() == 1 + elif quant_config.is_per_act_token: + assert a1q_scale.numel() == a1.numel() + assert a1q_scale.shape == a1.shape + else: + assert a1q_scale.numel() == a1.shape[0] * cdiv(a1.shape[1], quant_config.block_shape[1]) + assert a1q_scale.shape == (a1.shape[0], cdiv(a1.shape[1], quant_config.block_shape[1])) + #a1q_scale = group_broadcast(scale, a1q.shape) + if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -131,15 +142,21 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] # pad out scales if needed. TODO (bnell): do for non-scalar scales? - if False and scalar_scales: - print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") - a1q_scale = a1q_scale.repeat(a1q.shape[1], - 4 * torch.float32.itemsize) + if False and (scalar_scales or quant_config.is_per_tensor): + #print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") + a1q_scale = a1q_scale.repeat(1, 4 * torch.float32.itemsize) + else: + #a1q_scale = torch.repeat_interleave(a1q_scale, round_up(a1q_scale.shape[1], 16), dim=1) + #a1q_scale = torch.nn.functional.pad(a1q_scale, pad=(0, 16-a1q_scale.shape[1]), mode='replicate') + pass - a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) + if not quant_config.is_grouped: + a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}" + #print(f"FINAL SCALE SHAPE {a1q_scale.shape}") + assert a1q_scale is None or a1q_scale.ndim == 2, \ f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" @@ -166,16 +183,23 @@ def prepare( expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: float32_size = torch.float32.itemsize - block_size = (quant_config.block_shape[1] if quant_config. - block_shape is not None else 1) * float32_size + + if quant_config.is_per_act_token: + final_dim = expert_x.size(2) + assert final_dim % 4 == 0 #? + elif quant_config.is_per_tensor: + final_dim = 4 + else: + num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) + final_dim = round_up(num_blocks, 4) expert_x_scale_shape = ( num_local_experts, expert_x.size(1), - cdiv(expert_x.size(2), block_size) if not scalar_scales else 1, + final_dim, ) - print(f"EXPERT_X_SCALE {expert_x_scale_shape}") + #print(f"EXPERT_X_SCALE {expert_x_scale_shape}") expert_x_scale = torch.zeros( expert_x_scale_shape, @@ -200,9 +224,6 @@ def prepare( ) #print(f"DISPATCH DONE {device}") - if expert_x_scale is not None: - expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] - if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 From 0c06d4b7cf388b71ed1d365195384d426192ac48 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 24 Jun 2025 02:55:46 +0000 Subject: [PATCH 42/77] per token + grouped broken Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 4 +-- .../layers/fused_moe/fused_batched_moe.py | 27 +++++++++---------- .../layers/fused_moe/pplx_prepare_finalize.py | 2 +- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 0810b5c4cb04..ee831dda57a4 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -100,8 +100,8 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("block_shape", [None, [128, 128]]) -@pytest.mark.parametrize("per_act_token_quant", [False, True]) +@pytest.mark.parametrize("block_shape", [None])#, [128, 128]]) +@pytest.mark.parametrize("per_act_token_quant", [False])#, True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype, block_shape: Optional[list[int]], diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 77d58582f93c..e25b80a0ba54 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -63,17 +63,15 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm - ) #+ (expert_id * stride_ase) + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm #+ (expert_id * stride_ase) offs_bsn = offs_n // group_n - b_scale_ptrs = (b_scale_ptr + - offs_bsn * stride_bsn) + expert_id * stride_bse + b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + + offs_bsn * stride_bsn) # channel-wise elif per_channel_quant: # TODO: probably not correct - b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ - None, :] * stride_bsn + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations # + (expert_id * stride_ase)?? @@ -300,16 +298,14 @@ def batched_triton_kernel( cta_n_start * stride_cn) if use_fp8_w8a8: + a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) # block-wise - if (group_k > 0 and group_n > 0) or per_channel_quant: - a_scale_ptr = a_scale_ptr + (expert_id * - stride_ase) + cta_m_start * stride_asm - #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) - # (?) b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn - # channel-wise or tensor-wise - else: - a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) - #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + if group_k > 0 and group_n > 0: + a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm + b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + elif per_channel_quant: + a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm + b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + cta_n_start * stride_bsn expert_triton_kernel( a_ptr, @@ -532,6 +528,7 @@ def prepare( self.max_num_tokens, hidden_dim) + # empty? b_a1_scale = torch.zeros(scale_shape, dtype=torch.float32, device=a1.device) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index bfce386d8538..fa5df509d5d4 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -129,7 +129,6 @@ def prepare( else: assert a1q_scale.numel() == a1.shape[0] * cdiv(a1.shape[1], quant_config.block_shape[1]) assert a1q_scale.shape == (a1.shape[0], cdiv(a1.shape[1], quant_config.block_shape[1])) - #a1q_scale = group_broadcast(scale, a1q.shape) if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -201,6 +200,7 @@ def prepare( #print(f"EXPERT_X_SCALE {expert_x_scale_shape}") + # empty? expert_x_scale = torch.zeros( expert_x_scale_shape, dtype=torch.float32, From 8d4e287db6ae36d031c747bd10baa57206ccd9a0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 24 Jun 2025 17:59:20 +0000 Subject: [PATCH 43/77] a scales working, b scales not working Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 14 +++- .../layers/fused_moe/fused_batched_moe.py | 77 ++++++++++++------- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index ee831dda57a4..ed1f7ce94e93 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -100,8 +100,8 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("block_shape", [None])#, [128, 128]]) -@pytest.mark.parametrize("per_act_token_quant", [False])#, True]) +@pytest.mark.parametrize("block_shape", [[128, 128]]) # [None])#, [128, 128]]) +@pytest.mark.parametrize("per_act_token_quant", [False, True])# [False])# ,True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype, block_shape: Optional[list[int]], @@ -162,6 +162,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, assert A_q.dtype == B_q.dtype + #A_scale.fill_(1) + B_scale.fill_(1) + invoke_moe_batched_triton_kernel( A_q, B_q, @@ -204,7 +207,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.float32: (1e-2, 1e-2), }[test_output.dtype] - torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) + if False: + torch.set_printoptions(profile="full") + print(f"REF_OUTPUT {q_ref_output.shape}\n{q_ref_output}") + print(f"TRITON {test_output.shape}\n{test_output}") + + #torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index e25b80a0ba54..311fce3a3429 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -50,7 +50,7 @@ def moe_mmk( compute_type: tl.constexpr, use_w8a8: tl.constexpr, use_w8a16: tl.constexpr, - per_channel_quant: tl.constexpr, + per_act_token_quant: tl.constexpr, ): offs_k = tl.arange(0, BLOCK_K) @@ -63,25 +63,33 @@ def moe_mmk( if use_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + offs_m * stride_asm #+ (expert_id * stride_ase) + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm offs_bsn = offs_n // group_n - b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + - offs_bsn * stride_bsn) + b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn - # channel-wise - elif per_channel_quant: - # TODO: probably not correct - b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn + # per act token + elif per_act_token_quant: + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None] + + b_scale_ptrs = b_scale_ptr + offs_n[None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) + + # Load per-token scale for activations # + (expert_id * stride_ase)?? - a_scale_ptrs = a_scale_ptr + offs_m * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] + #a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + #a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] + + # TODO: probably not correct + #b_scale_ptrs = b_scale_ptr + expert_id * stride_bse #+ offs_n[None, :] * stride_bsn + #b_scale = tl.load(b_scale_ptrs) # tensor-wise else: - a_scale = tl.load(a_scale_ptr) # + (expert_id * stride_ase) - b_scale = tl.load(b_scale_ptr + expert_id * stride_bse) + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -108,26 +116,33 @@ def moe_mmk( other=0.0) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, - None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + elif False and per_act_token_quant: + a_scale = tl.load(a_scale_ptrs + offs_k[None, :] * stride_ask, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: - if use_w8a8: - # acc used to enable fp8_fast_accum - accumulator = tl.dot(a, b, acc=accumulator) - else: - accumulator += tl.dot(a, b) + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk + if False and per_act_token_quant: + a_scale_ptrs += BLOCK_K * stride_ask + b_scale_ptrs += BLOCK_K * stride_bsk + if use_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) - else: + elif True or not per_act_token_quant: accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) @@ -169,7 +184,7 @@ def expert_triton_kernel( # Quantization schemes use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, - per_channel_quant: tl.constexpr, + per_act_token_quant: tl.constexpr, # Kernel config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -181,6 +196,7 @@ def expert_triton_kernel( offs_k = tl.arange(0, BLOCK_K) mask_m = offs_m < M + # Make grids of a + b pointers a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn @@ -217,7 +233,7 @@ def expert_triton_kernel( compute_type, use_fp8_w8a8, use_int8_w8a16, - per_channel_quant) + per_act_token_quant) # store in C offs_cn = tl.arange(0, BLOCK_N) @@ -266,17 +282,19 @@ def batched_triton_kernel( # Quantization schemes use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, - per_channel_quant: tl.constexpr, + per_act_token_quant: tl.constexpr, # Kernel config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr): + BLOCK_K: tl.constexpr, +): expert_id = tl.program_id(axis=0) e_num_tokens = tl.load(expert_num_tokens + expert_id) if e_num_tokens == 0: # Early exit return + # axis 1 is M_blocks * N_blocks pid_mn = tl.program_id(axis=1) #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) @@ -298,14 +316,15 @@ def batched_triton_kernel( cta_n_start * stride_cn) if use_fp8_w8a8: - a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + a_scale_ptr = a_scale_ptr + expert_id * stride_ase + b_scale_ptr = b_scale_ptr + expert_id * stride_bse # block-wise if group_k > 0 and group_n > 0: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm - b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) - elif per_channel_quant: + # b group advancement? + elif False and per_act_token_quant: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm - b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + cta_n_start * stride_bsn + b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn expert_triton_kernel( a_ptr, @@ -338,7 +357,7 @@ def batched_triton_kernel( # Quantization schemes use_fp8_w8a8, use_int8_w8a16, - per_channel_quant, + per_act_token_quant, # Kernel config BLOCK_M, BLOCK_N, From 62404e3b328d16a1c0321370957d2c0e1a6e0c95 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 24 Jun 2025 19:22:25 +0000 Subject: [PATCH 44/77] blocked working Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 5 ++--- .../layers/fused_moe/fused_batched_moe.py | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index ed1f7ce94e93..e3e4ee94d306 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -100,7 +100,7 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("block_shape", [[128, 128]]) # [None])#, [128, 128]]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) # [None])#, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True])# [False])# ,True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype, @@ -162,8 +162,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, assert A_q.dtype == B_q.dtype - #A_scale.fill_(1) - B_scale.fill_(1) + #B_scale.fill_(0.5) invoke_moe_batched_triton_kernel( A_q, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 311fce3a3429..4bd58e92399f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -39,6 +39,7 @@ def moe_mmk( # Offsets and masks offs_m, offs_n, + offs_bn, mask_m, # Block size for block-wise quantization group_n: tl.constexpr, @@ -64,7 +65,7 @@ def moe_mmk( # block-wise if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + offs_m * stride_asm - offs_bsn = offs_n // group_n + offs_bsn = offs_bn // group_n b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn # per act token @@ -142,7 +143,7 @@ def moe_mmk( elif use_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) - elif True or not per_act_token_quant: + else: #if True or not per_act_token_quant: accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) @@ -178,6 +179,8 @@ def expert_triton_kernel( stride_bse, stride_bsk, stride_bsn, + # offsets + offs_bn, # Blockwise quantization data group_n, group_k, @@ -222,6 +225,7 @@ def expert_triton_kernel( # Offsets and masks offs_m, offs_n, + offs_bn, mask_m, # Block size for block-wise quantization group_n, @@ -315,12 +319,15 @@ def batched_triton_kernel( c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + cta_n_start * stride_cn) + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N + if use_fp8_w8a8: a_scale_ptr = a_scale_ptr + expert_id * stride_ase b_scale_ptr = b_scale_ptr + expert_id * stride_bse # block-wise if group_k > 0 and group_n > 0: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm + #b_scale_ptr = b_scale_ptr + offs_bn * stride_bsn # b group advancement? elif False and per_act_token_quant: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm @@ -351,6 +358,8 @@ def batched_triton_kernel( stride_bse, stride_bsk, stride_bsn, + # offsets + offs_bn, # Blockwise quantization data group_n, group_k, @@ -404,12 +413,13 @@ def invoke_moe_batched_triton_kernel( if B_scale is not None: if B_scale.ndim == 1: stride_bse = 1 - stride_bsn = 0 stride_bsk = 0 + stride_bsn = 0 else: stride_bse = B_scale.stride(0) - stride_bsn = B_scale.stride(1) stride_bsk = B_scale.stride(2) + stride_bsn = B_scale.stride(1) + else: stride_bse = 0 stride_bsk = 0 From b4dc46e6e926308fd64e3634b480404e69d3ee5e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 24 Jun 2025 20:22:12 +0000 Subject: [PATCH 45/77] per_act_token working Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 16 +++++++-------- .../layers/fused_moe/fused_batched_moe.py | 20 +++---------------- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index e3e4ee94d306..d7aa7bf7aec3 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -100,8 +100,8 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("block_shape", [None, [128, 128]]) # [None])#, [128, 128]]) -@pytest.mark.parametrize("per_act_token_quant", [False, True])# [False])# ,True]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype, block_shape: Optional[list[int]], @@ -162,8 +162,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, assert A_q.dtype == B_q.dtype - #B_scale.fill_(0.5) - invoke_moe_batched_triton_kernel( A_q, B_q, @@ -211,7 +209,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, print(f"REF_OUTPUT {q_ref_output.shape}\n{q_ref_output}") print(f"TRITON {test_output.shape}\n{test_output}") - #torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) @@ -318,7 +316,7 @@ def test_fused_moe_batched_experts( # atol=2e-2, # rtol=2e-2) - # torch.testing.assert_close(triton_output, - # batched_output, - # atol=2e-2, - # rtol=2e-2) + torch.testing.assert_close(triton_output, + batched_output, + atol=2e-2, + rtol=2e-2) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 4bd58e92399f..b2a655991754 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -74,19 +74,9 @@ def moe_mmk( a_scale_ptrs = a_scale_ptr + offs_m * stride_asm a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None] - b_scale_ptrs = b_scale_ptr + offs_n[None, :] * stride_bsn + b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) - - # Load per-token scale for activations - # + (expert_id * stride_ase)?? - #a_scale_ptrs = a_scale_ptr + offs_m * stride_asm - #a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] - - # TODO: probably not correct - #b_scale_ptrs = b_scale_ptr + expert_id * stride_bse #+ offs_n[None, :] * stride_bsn - #b_scale = tl.load(b_scale_ptrs) - # tensor-wise else: a_scale = tl.load(a_scale_ptr) @@ -134,10 +124,6 @@ def moe_mmk( a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk - if False and per_act_token_quant: - a_scale_ptrs += BLOCK_K * stride_ask - b_scale_ptrs += BLOCK_K * stride_bsk - if use_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_w8a8: @@ -329,9 +315,9 @@ def batched_triton_kernel( a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm #b_scale_ptr = b_scale_ptr + offs_bn * stride_bsn # b group advancement? - elif False and per_act_token_quant: + elif per_act_token_quant: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm - b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn + # b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn expert_triton_kernel( a_ptr, From c3bddec234ce691854c78f38e8162d219d325fdb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 25 Jun 2025 03:09:45 +0000 Subject: [PATCH 46/77] qwen works, rh-ds broken now, pplx_moe tests not all working Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 16 +-- tests/kernels/moe/test_pplx_moe.py | 131 ++++++------------ tests/kernels/quant_utils.py | 17 +++ .../model_executor/layers/fused_moe/config.py | 2 + .../fused_moe/deepep_ll_prepare_finalize.py | 10 +- .../layers/fused_moe/fused_batched_moe.py | 46 ++---- .../layers/fused_moe/fused_moe.py | 88 ------------ .../layers/fused_moe/pplx_prepare_finalize.py | 50 ++----- vllm/model_executor/layers/fused_moe/utils.py | 19 +++ .../compressed_tensors_moe.py | 2 +- 10 files changed, 114 insertions(+), 267 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index d7aa7bf7aec3..467eddadf2b8 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -204,13 +204,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.float32: (1e-2, 1e-2), }[test_output.dtype] - if False: - torch.set_printoptions(profile="full") - print(f"REF_OUTPUT {q_ref_output.shape}\n{q_ref_output}") - print(f"TRITON {test_output.shape}\n{test_output}") - torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) - #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) @@ -277,6 +271,7 @@ def test_fused_moe_batched_experts( per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) + baseline_output = torch_experts( a, w1, @@ -302,20 +297,11 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - #print(f"TORCH {baseline_output.shape}\n{baseline_output}") - #print(f"TRITON {triton_output.shape}\n{triton_output}") - #print(f"BATCHED {batched_output.shape}\n{batched_output}") - torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2) - # torch.testing.assert_close(triton_output, - # baseline_output, - # atol=2e-2, - # rtol=2e-2) - torch.testing.assert_close(triton_output, batched_output, atol=2e-2, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index eea4293cc6c7..23dd8eeb3297 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -20,7 +20,7 @@ from tests.kernels.moe.utils import make_test_weights, naive_batched_moe from tests.kernels.utils import torch_experts -from tests.kernels.quant_utils import dequant +from tests.kernels.quant_utils import batched_dequant from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig @@ -42,16 +42,19 @@ ) PPLX_PREPARE_COMBOS = [ -# (1, 128, 128), + # TODO: figure out why this fails + #(1, 128, 128), + (2, 128, 512), + (3, 1024, 2048), (4, 128, 128), (32, 1024, 512), -# (45, 512, 2048), + (45, 512, 2048), (64, 1024, 512), (222, 2048, 1024), ] PPLX_MOE_COMBOS = [ - (1, 128, 128), +# (1, 128, 128), (2, 128, 512), (3, 1024, 2048), (32, 128, 1024), @@ -203,7 +206,7 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: def dummy_work(a: torch.Tensor) -> torch.Tensor: - return a # * 1.5 + return a * 1.1 def pplx_prepare_finalize( @@ -271,6 +274,13 @@ def pplx_prepare_finalize( chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( a_chunk, None, @@ -288,16 +298,10 @@ def pplx_prepare_finalize( ), ) - # Do some fake work - #print(f"INTER {b_a.shape} {b_a_scale.shape if b_a_scale is not None else None}") - b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) + #print(f"B_A_SCALE = {b_a.shape}, {b_a_scale.shape if b_a_scale is not None else None}, {per_act_token_quant} {block_shape}, {a_chunk.shape}") + # TOOD: shouldn't need batched_dequant - out = torch.full( - (max_num_tokens, hidden_dim), - torch.nan, - dtype=a.dtype, - device=device, - ) + b_a = dummy_work(batched_dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) prepare_finalize.finalize( out, @@ -339,33 +343,19 @@ def _pplx_prepare_finalize( cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name - #device = pgi.device - topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) m, k = a.shape - a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) #.to(device) + a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) - if True: - torch_output = (a_rep.view(m, topk, k) * - topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1) - else: - import vllm._custom_ops as ops - a_rep = a_rep.view(m, topk, k) - a_rep.mul_(topk_weight.view(m, topk, 1).to(a_rep.dtype)) - torch_output = torch.empty_like(a) - ops.moe_sum(a_rep, torch_output) + torch_output = (a_rep.view(m, topk, k) * + topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1) pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, num_experts, quant_dtype, block_shape, per_act_token_quant, group_name) - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) - - #torch.set_printoptions(profile="full") - #print(f"PPLX {pplx_output.shape}\n{pplx_output.shape}") - #print(f"TORCH {torch_output.shape}\n{torch_output.shape}") + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pgi.device) torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) @@ -373,15 +363,14 @@ def _pplx_prepare_finalize( nvshmem_finalize() -# TODO (bnell): this test point does not work for odd M due to how the test is -# written, not due to limitations of the pplx kernels. The pplx_moe -# test below is able to deal with odd M. +# TODO (bnell): this test point does not work for M==1 due to how the test +# is written, not due to limitations of the pplx kernels. @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) -@pytest.mark.parametrize("per_act_token_quant", [False]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx @@ -415,8 +404,6 @@ def test_pplx_prepare_finalize( world_size, dp_size = world_dp_size device = "cuda" - #print(f"MNK = {mnk}") - a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 score = torch.randn((m, e), device=device, dtype=act_dtype) @@ -510,8 +497,12 @@ def pplx_moe( w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) if w1_scale is not None: - w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) - w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) + if not per_act_token_quant: + w1_scale_chunk = w1_scale + w2_scale_chunk = w2_scale + else: + w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) + w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) else: w1_scale_chunk = None w2_scale_chunk = None @@ -562,48 +553,6 @@ def pplx_moe( return out -def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): - assert torch.cuda.current_device() == pgi.local_rank - - num_experts = w1.shape[0] - device = pgi.device - rank = pgi.rank - world_size = pgi.world_size - max_num_tokens = rank_chunk(a.shape[0], 0, world_size) - - prepare_finalize = BatchedPrepareAndFinalize( - max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - rank=rank, - ) - - experts = NaiveBatchedExperts(max_num_tokens=a.shape[0], - world_size=1, - dp_size=1) - - fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - # Note: workers with the same dp_rank must use the exact same inputs. - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) - chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) - - out = fused_experts( - a_chunk, - # Chunking weights like this only works for batched format - chunk_by_rank(w1, rank, world_size).to(device), - chunk_by_rank(w2, rank, world_size).to(device), - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_experts) - - return out - - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -654,18 +603,22 @@ def _pplx_moe( quant_dtype=qtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape) + pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, qtype, per_act_token_quant, block_shape) - # TODO (bnell): fix + re-enable - #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, - # topk_ids) - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) + # all reduce on pplx? + #torch.distributed.all_reduce(pplx_output) + + batched_output = naive_batched_moe(a, w1, w2, topk_weight, + topk_ids, w1_s, w2_s, qtype, per_act_token_quant, block_shape) + + chunked_torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2) + #torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) if use_internode: nvshmem_finalize() diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index d0dc85f25755..2970a7c9af61 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -277,6 +277,23 @@ def dequant( return t.to(out_dtype) +def batched_dequant( + t: torch.Tensor, scale: + Optional[torch.Tensor], + block_shape: Optional[list[int]], + per_act_token_quant: bool, + out_dtype: Optional[torch.dtype] = torch.float32, +) -> torch.Tensor: + if scale is not None: + assert t.shape[0] == scale.shape[0] + out = torch.empty_like(t, dtype=out_dtype) + for e in range(t.shape[0]): + out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, out_dtype) + return out + + return t.to(out_dtype) + + def native_batched_masked_quant_matmul( A: torch.Tensor, B: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 9e8c5afc4de3..c5aca64b8f2b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -310,6 +310,8 @@ def __post_init__(self): logger.debug("Using FusedMoEConfig::max_num_tokens=%d", self.max_num_tokens) + assert self.max_num_tokens > 0 + @property def quant_dtype(self) -> Optional[torch.dtype]: if self.quant_config is not None: diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 4ef2948156ab..b315b4a97f04 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -27,7 +27,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, expert_x_fp32 = expert_x_fp8.to(torch.float32).view( num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) - return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -39,8 +39,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # specific hidden sizes. SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168] - def __init__(self, buffer: deep_ep.Buffer, max_tokens_per_rank: int, - world_size: int, dp_size: int, use_fp8_dispatch: bool): + def __init__(self, + buffer: deep_ep.Buffer, + max_tokens_per_rank: int, + world_size: int, + dp_size: int, + use_fp8_dispatch: bool = False): super().__init__() self.buffer = buffer diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index b2a655991754..713f22211b68 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -107,15 +107,9 @@ def moe_mmk( other=0.0) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] - elif False and per_act_token_quant: - a_scale = tl.load(a_scale_ptrs + offs_k[None, :] * stride_ask, - mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), - other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) - accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: + # acc used to enable fp8_fast_accum accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) @@ -129,7 +123,7 @@ def moe_mmk( elif use_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) - else: #if True or not per_act_token_quant: + else: accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) @@ -310,14 +304,12 @@ def batched_triton_kernel( if use_fp8_w8a8: a_scale_ptr = a_scale_ptr + expert_id * stride_ase b_scale_ptr = b_scale_ptr + expert_id * stride_bse + # block-wise if group_k > 0 and group_n > 0: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm - #b_scale_ptr = b_scale_ptr + offs_bn * stride_bsn - # b group advancement? elif per_act_token_quant: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm - # b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn expert_triton_kernel( a_ptr, @@ -377,8 +369,6 @@ def invoke_moe_batched_triton_kernel( per_act_token_quant: bool, block_shape: Optional[list[int]] = None): - #print(f"TRITON MOE BATCHED {use_fp8_w8a8}, {per_act_token_quant}, {block_shape}") - assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) @@ -543,8 +533,7 @@ def prepare( self.max_num_tokens, hidden_dim) - # empty? - b_a1_scale = torch.zeros(scale_shape, + b_a1_scale = torch.empty(scale_shape, dtype=torch.float32, device=a1.device) else: @@ -720,8 +709,6 @@ def apply( N = w1.size(1) // 2 f32 = torch.float32 - #output.fill_(0) - for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor if (torch.compiler.is_compiling() @@ -778,28 +765,20 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + # TODO: fix this if (True or torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) - assert A_scale is None or A_scale.ndim <= 2 + assert A_scale is None or A_scale.ndim <= 2, f"{A_scale.shape if A_scale is not None else None}" A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, hidden_dim), A_scale, qtype, per_act_token_quant, block_shape) A_q = A_q.view(E, -1, hidden_dim) - - # for e in range(len(expert_num_tokens)): - # num = expert_num_tokens[e] - # A_q_scale[e, num:].fill_(0) - A_q_scale = maybe_fix_scales(A_q_scale, E) - #print(f"A2Q_SCALE {A_q_scale.shape}\n{A_q_scale}") - #A_q_scale.fill_(0.0001) - #print(f"A_q_scale.stride = {A_q_scale.stride()}") - return A_q, A_q_scale if qtype is not None: @@ -990,9 +969,8 @@ def apply( if self.use_fp8_w8a8: intermediate_cache1.fill_(0) - #print(f"A1_SCALES {a1q_scale.shape}") a1q_scale = maybe_fix_scales(a1q_scale, E) - a2_scale = maybe_fix_scales(a2_scale, E) + #a2_scale = maybe_fix_scales(a2_scale, E) # MM1 invoke_moe_batched_triton_kernel( @@ -1016,21 +994,19 @@ def apply( # TODO: would be nice to use expert_num_tokens here to reduce # garbage compute if False: - # TODO: check expert_num_tokens tmp = torch.empty_like(intermediate_cache2[0]) for e in range(E): num_tokens = expert_num_tokens[e] - self.activation(activation, tmp[:num_tokens], - intermediate_cache1[e, :num_tokens]) - intermediate_cache2[e, :num_tokens] = tmp[:num_tokens] + if num_tokens > 0: + self.activation(activation, tmp[:num_tokens], + intermediate_cache1[e, :num_tokens]) + intermediate_cache2[e, :num_tokens] = tmp[:num_tokens] else: self.activation( activation, intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N)) - #print(f"BATCHED ACT {intermediate_cache2.shape}\n{intermediate_cache2}") - qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( intermediate_cache2, a2_scale, max_num_tokens, E, N, expert_num_tokens, self.quant_dtype, self.per_act_token_quant, self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 64d24f86627b..75712b8e3a4d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -466,72 +466,6 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -def prepare_scales( - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - quant_dtype: Optional[torch.dtype], - block_shape: Optional[list[int]], - msg: str, -): - from vllm.utils import round_up - max_num_tokens = round_up(a1.shape[0], 64) - num_tokens, hidden_dim = a1.size() - #topk = topk_ids.size(1) - - tokens_per_expert = torch.zeros(num_experts, - dtype=torch.int, - device=a1.device) - - num_local_experts = num_experts - - b_a1 = torch.zeros( - (num_local_experts, max_num_tokens, hidden_dim), - dtype=quant_dtype if quant_dtype is not None else a1.dtype, - device=a1.device) - - if quant_dtype is not None: - if block_shape is not None: - _, block_k = block_shape - k_tiles = (hidden_dim + block_k - 1) // block_k - scale_shape = (num_local_experts, max_num_tokens, k_tiles) - else: - num = 1 - scale_shape = (num_local_experts, num, 1) - - b_a1_scale = torch.zeros(scale_shape, - dtype=torch.float32, - device=a1.device) - else: - assert a1_scale is None - b_a1_scale = None - - first_expert = 0 - last_expert = first_expert + num_local_experts - - for expert_id in range(first_expert, last_expert): - topks = torch.any(topk_ids == expert_id, dim=1).flatten() - rows = torch.count_nonzero(topks.flatten()) - rhs = a1[:topks.numel()][topks] - idx = expert_id - first_expert - b_a1[idx, :rows, :] = rhs - if quant_dtype is not None: - rhs_a1_scale = a1_scale[:topks.numel()][topks] - if block_shape is None: - b_a1_scale[idx] = rhs_a1_scale - else: - assert rows == rhs_a1_scale.shape[0] and b_a1_scale.shape[ - -1] == rhs_a1_scale.shape[-1] - b_a1_scale[idx, :rows] = rhs_a1_scale - - tokens_per_expert[idx] = rows - - print(f"{msg} {b_a1_scale.shape}\n{b_a1_scale}") - - return b_a1, b_a1_scale, tokens_per_expert - - def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @@ -1396,17 +1330,6 @@ def fused_experts_impl( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - if False: - prepare_scales( - qcurr_hidden_states, - a1q_scale, - curr_topk_ids, - global_num_experts, - torch.float8_e4m3fn if use_fp8_w8a8 else None, - block_shape, - "First", - ) - invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, @@ -1444,17 +1367,6 @@ def fused_experts_impl( per_act_token_quant=per_channel_quant, block_shape=block_shape) - if False: - prepare_scales( - qintermediate_cache2, - a2q_scale, - curr_topk_ids, - global_num_experts, - torch.float8_e4m3fn if use_fp8_w8a8 else None, - block_shape, - "Second", - ) - invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index fa5df509d5d4..a93afc82a00f 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -8,7 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + moe_kernel_quantize_input, _validate_scale_shape) from vllm.utils import cdiv, round_up @@ -120,15 +120,12 @@ def prepare( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) - if quant_config.quant_dtype is not None: - if quant_config.is_per_tensor: - assert a1q_scale.numel() == 1 - elif quant_config.is_per_act_token: - assert a1q_scale.numel() == a1.numel() - assert a1q_scale.shape == a1.shape - else: - assert a1q_scale.numel() == a1.shape[0] * cdiv(a1.shape[1], quant_config.block_shape[1]) - assert a1q_scale.shape == (a1.shape[0], cdiv(a1.shape[1], quant_config.block_shape[1])) + _validate_scale_shape( + a1q, + a1q_scale, + quant_config.per_act_token_quant, + quant_config.block_shape + ) if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -140,22 +137,9 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] - # pad out scales if needed. TODO (bnell): do for non-scalar scales? - if False and (scalar_scales or quant_config.is_per_tensor): - #print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}") - a1q_scale = a1q_scale.repeat(1, 4 * torch.float32.itemsize) - else: - #a1q_scale = torch.repeat_interleave(a1q_scale, round_up(a1q_scale.shape[1], 16), dim=1) - #a1q_scale = torch.nn.functional.pad(a1q_scale, pad=(0, 16-a1q_scale.shape[1]), mode='replicate') - pass - if not quant_config.is_grouped: a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) - #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}" - - #print(f"FINAL SCALE SHAPE {a1q_scale.shape}") - assert a1q_scale is None or a1q_scale.ndim == 2, \ f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" @@ -172,7 +156,6 @@ def prepare( ) num_dp = self.world_size // self.dp_size - #print(f"EXPERT_X {(num_local_experts, self.max_num_tokens * num_dp, hidden_dim)}, {a1q.dtype}, {device}") expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), dtype=a1q.dtype, @@ -184,24 +167,24 @@ def prepare( float32_size = torch.float32.itemsize if quant_config.is_per_act_token: + token_dim = expert_x.size(1) final_dim = expert_x.size(2) assert final_dim % 4 == 0 #? elif quant_config.is_per_tensor: + token_dim = 1 final_dim = 4 else: num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) final_dim = round_up(num_blocks, 4) + token_dim = expert_x.size(1) expert_x_scale_shape = ( num_local_experts, - expert_x.size(1), + token_dim, final_dim, ) - #print(f"EXPERT_X_SCALE {expert_x_scale_shape}") - - # empty? - expert_x_scale = torch.zeros( + expert_x_scale = torch.empty( expert_x_scale_shape, dtype=torch.float32, device=expert_x.device, @@ -211,8 +194,6 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - #print(f"DISPATCH X={expert_x.shape}, X_SCALE={expert_x_scale.shape}, A={a1q.shape}, A_SCALE={a1q_scale.shape}, TOPK={topk_ids}") - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -222,7 +203,6 @@ def prepare( indices=topk_ids, bound_m=bound_m, ) - #print(f"DISPATCH DONE {device}") if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] @@ -243,8 +223,8 @@ def finalize( # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None - assert topk_ids.size(0) == num_tokens, ( - f"{topk_ids.size(0)} == {num_tokens}") + #assert topk_ids.size(0) == num_tokens, ( + # f"{topk_ids.size(0)} == {num_tokens}") assert output.size(0) <= self.max_num_tokens, ( f"{output.size(0)} <= {self.max_num_tokens}") assert output.size(1) == fused_expert_output.size(-1) @@ -253,10 +233,8 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - #print(f"COMBINE {output.device}") self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) - #print(f"COMBINE DONE {output.device}") diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 52346f797440..37f9581cd2e8 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -111,3 +111,22 @@ def maybe_fix_scales(scales: Optional[torch.Tensor], scales = scales.view(num_experts, -1, scales.size(-1)) return scales + + +def _validate_scale_shape( + a: torch.Tensor, + a_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]], +) -> None: + if a_scale is None: + return + + if not per_act_token_quant and block_shape is None: + assert a_scale.numel() == 1, f"{a_scale.shape}" + elif per_act_token_quant: + assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, ( + f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1") + else: + expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) + assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 3f0a1c2634d6..3288d2bddb75 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -864,7 +864,7 @@ def select_gemm_impl( experts = CutlassExpertsFp8( num_experts, - None, #moe.in_dtype, + moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, use_batched_format=use_batched_format, From 185f09065b1ecdb8619d6c126aed0aaf062055f0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 25 Jun 2025 20:14:21 +0000 Subject: [PATCH 47/77] both models work Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 33 +++- tests/kernels/moe/test_pplx_moe.py | 155 +++++++++++++----- .../batched_triton_or_deep_gemm_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 64 ++++++-- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 5 +- .../layers/fused_moe/modular_kernel.py | 4 + .../layers/fused_moe/pplx_prepare_finalize.py | 12 +- .../model_executor/layers/quantization/fp8.py | 2 + 9 files changed, 213 insertions(+), 66 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 467eddadf2b8..e37050f841d0 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -207,6 +207,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) +# @pytest.mark.parametrize("m", [6, 16, 199, 200, 256]) +# @pytest.mark.parametrize("n", [2816//2]) +# @pytest.mark.parametrize("k", [2048]) +# @pytest.mark.parametrize("e", [32]) +# @pytest.mark.parametrize("topk", [6]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +# @pytest.mark.parametrize("per_act_token_quant", [False]) +# @pytest.mark.parametrize("block_shape", [None]) @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -214,6 +222,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) +@pytest.mark.parametrize("input_scales", [False]) def test_fused_moe_batched_experts( m: int, n: int, @@ -223,6 +232,7 @@ def test_fused_moe_batched_experts( dtype: torch.dtype, per_act_token_quant: bool, block_shape: Optional[list[int]], + input_scales: bool, ): current_platform.seed_everything(7) @@ -257,9 +267,17 @@ def test_fused_moe_batched_experts( per_act_token_quant=per_act_token_quant, ) + if input_scales and quant_dtype is not None: + a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32) + a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32) + else: + a1_scale = None + a2_scale = None + with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - batched_output = batched_moe( + + baseline_output = torch_experts( a, w1, w2, @@ -267,12 +285,14 @@ def test_fused_moe_batched_experts( topk_ids, w1_scale=w1_s, w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) - baseline_output = torch_experts( + batched_output = naive_batched_moe( a, w1, w2, @@ -280,11 +300,14 @@ def test_fused_moe_batched_experts( topk_ids, w1_scale=w1_s, w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, - block_shape=block_shape) + block_shape=block_shape, + ) - triton_output = triton_moe( + triton_output = batched_moe( a, w1, w2, @@ -292,6 +315,8 @@ def test_fused_moe_batched_experts( topk_ids, w1_scale=w1_s, w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 23dd8eeb3297..98268588a2be 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -41,9 +41,10 @@ reason="Requires PPLX kernels", ) -PPLX_PREPARE_COMBOS = [ +PPLX_COMBOS = [ # TODO: figure out why this fails #(1, 128, 128), + (2, 128, 512), (3, 1024, 2048), (4, 128, 128), @@ -51,16 +52,13 @@ (45, 512, 2048), (64, 1024, 512), (222, 2048, 1024), -] + (256, 1408, 2048), -PPLX_MOE_COMBOS = [ -# (1, 128, 128), - (2, 128, 512), - (3, 1024, 2048), - (32, 128, 1024), - (45, 512, 2048), - (64, 1024, 1024), - (222, 1024, 2048), + #(6, 1408, 2048), + #(16, 1408, 2048), + #(199, 1408, 2048), + #(200, 1408, 2048), + #(256, 1408, 2048), ] NUM_EXPERTS = [8, 64] @@ -281,10 +279,17 @@ def pplx_prepare_finalize( device=device, ) + if quant_dtype is not None and not per_act_token_quant and block_shape is None: + a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + else: + a1_scale = None + a2_scale = None + b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( a_chunk, - None, - None, + a1_scale, + a2_scale, chunk_topk_weight, chunk_topk_ids, num_experts, @@ -365,7 +370,7 @@ def _pplx_prepare_finalize( # TODO (bnell): this test point does not work for M==1 due to how the test # is written, not due to limitations of the pplx kernels. -@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) +@pytest.mark.parametrize("mnk", PPLX_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @@ -424,7 +429,9 @@ def pplx_moe( topk_ids: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - qtype: Optional[torch.dtype] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, per_act_token_quant=False, block_shape: Optional[list[int]] = None, use_compile: bool = False, @@ -443,7 +450,7 @@ def pplx_moe( max_num_tokens, hidden_dim, a.dtype, - qtype, + quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) @@ -479,7 +486,7 @@ def pplx_moe( experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, world_size=world_size, dp_size=dp_size, - use_fp8_w8a8=qtype == torch.float8_e4m3fn, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, block_shape=block_shape) fused_experts = FusedMoEModularKernel( @@ -497,12 +504,8 @@ def pplx_moe( w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) if w1_scale is not None: - if not per_act_token_quant: - w1_scale_chunk = w1_scale - w2_scale_chunk = w2_scale - else: - w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) - w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) + w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) + w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) else: w1_scale_chunk = None w2_scale_chunk = None @@ -527,6 +530,8 @@ def pplx_moe( chunk_topk_ids, w1_scale=w1_scale_chunk, w2_scale=w2_scale_chunk, + a1_scale=a1_scale, + a2_scale=a2_scale, global_num_experts=num_experts) if use_cudagraphs: @@ -541,6 +546,8 @@ def pplx_moe( chunk_topk_ids, w1_scale=w1_scale_chunk, w2_scale=w2_scale_chunk, + a1_scale=a1_scale, + a2_scale=a2_scale, global_num_experts=num_experts) torch.cuda.synchronize() @@ -563,7 +570,7 @@ def _pplx_moe( topk: int, w1_s: Optional[torch.Tensor] = None, w2_s: Optional[torch.Tensor] = None, - qtype: Optional[torch.dtype] = None, + quant_dtype: Optional[torch.dtype] = None, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, use_internode: bool = False, @@ -585,48 +592,112 @@ def _pplx_moe( moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) device = torch.device("cuda", pgi.rank) + rank = pgi.rank + world_size = pgi.world_size a = a.to(device) w1 = w1.to(device) w2 = w2.to(device) w1_s = w1_s.to(device) if w1_s is not None else None w2_s = w2_s.to(device) if w2_s is not None else None + if quant_dtype is not None and not per_act_token_quant and block_shape is None: + a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + else: + a1_scale = None + a2_scale = None + with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - quant_dtype=qtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape) - - pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, - a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, - qtype, per_act_token_quant, block_shape) + + if False: + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + topk_weight_chunk = chunk_by_rank(topk_weight, rank, world_size).to(device) + topk_ids_chunk = chunk_by_rank(topk_ids, rank, world_size).to(device) + w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) + w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + + if w1_s is not None: + w1_s_chunk = chunk_by_rank(w1_s, rank, world_size).to(device) + w2_s_chunk = chunk_by_rank(w2_s, rank, world_size).to(device) + else: + w1_s_chunk = None + w2_s_chunk = None + else: + a_chunk = a + topk_weight_chunk = topk_weight + topk_ids_chunk = topk_ids + w1_chunk = w1 + w2_chunk = w2 + w1_s_chunk = w1_s + w2_s_chunk = w2_s + + torch_output = torch_experts( + a_chunk, + w1_chunk, + w2_chunk, + topk_weight_chunk, + topk_ids_chunk, + w1_scale=w1_s_chunk, + w2_scale=w2_s_chunk, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + batched_output = naive_batched_moe( + a_chunk, + w1_chunk, + w2_chunk, + topk_weight_chunk, + topk_ids_chunk, + w1_scale=w1_s_chunk, + w2_scale=w2_s_chunk, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + pplx_output = pplx_moe( + group_name, + rank, + world_size, + dp_size, + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape) # all reduce on pplx? #torch.distributed.all_reduce(pplx_output) - batched_output = naive_batched_moe(a, w1, w2, topk_weight, - topk_ids, w1_s, w2_s, qtype, per_act_token_quant, block_shape) - chunked_torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2) - #torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) + torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) if use_internode: nvshmem_finalize() -@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) +@pytest.mark.parametrize("mnk", PPLX_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +#@pytest.mark.parametrize("e", [32]) +#@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 65bd4f49b57f..062f204798ef 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -8,7 +8,7 @@ BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + BatchedTritonExperts, NaiveBatchedExperts) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 713f22211b68..efe37e76b5c5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -381,6 +381,13 @@ def invoke_moe_batched_triton_kernel( grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) + # ????? + A_scale = maybe_fix_scales(A_scale, expert_num_tokens.shape[0]) + + if B_scale is not None and B_scale.ndim == 1: + assert B_scale.numel() == expert_num_tokens.shape[0] + B_scale = B_scale.view(-1, 1, 1) + assert A_scale is None or A_scale.ndim == 3, ( f"{0 if A_scale is None else A_scale.shape}") assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, ( @@ -543,6 +550,9 @@ def prepare( first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts + a1_scale = maybe_fix_2d_scales(a1_scale) + a2_scale = maybe_fix_2d_scales(a2_scale) + for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() rows = torch.count_nonzero(topks.flatten()) @@ -553,21 +563,25 @@ def prepare( rhs = a1[:topks.numel()][topks] if quant_config.quant_dtype is not None: if a1_scale is not None: - assert False, "NYI" - rhs_a1_scale = a1_scale[:topks.numel()][topks] + if quant_config.is_per_act_token: + rhs_a1_scale = a1_scale[:topks.numel()][topks] + else: + rhs_a1_scale = a1_scale else: rhs_a1_scale = None - b_a1[idx, :rows, :], b_s = (moe_kernel_quantize_input( + b_a1[idx, :rows, :], b_s = moe_kernel_quantize_input( rhs, rhs_a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, - )) - if quant_config.is_per_tensor: - b_a1_scale[idx] = b_s - else: + ) + if quant_config.is_per_act_token: + #print(f"B_S1 {b_s.shape}") b_a1_scale[idx, :rows] = b_s[:rows] + else: + #print(f"B_S2 {b_s.shape}") + b_a1_scale[idx, :b_s.shape[0]] = b_s else: b_a1[idx, :rows, :] = rhs @@ -704,14 +718,12 @@ def apply( assert num_local_experts == w1.size(0), ( f"{num_local_experts} == {w1.size(0)}") - assert a2_scale is None, "NYI" - N = w1.size(1) // 2 f32 = torch.float32 for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor - if (torch.compiler.is_compiling() + if (True or torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): num = hidden_states.shape[1] else: @@ -740,7 +752,7 @@ def apply( def maybe_fix_scales(scales: Optional[torch.Tensor], num_experts: int) -> Optional[torch.Tensor]: - if scales is not None: + if scales is not None and scales.ndim < 3: if scales.numel() == 1: scales = scales.view(1) scales = torch.repeat_interleave( @@ -754,6 +766,15 @@ def maybe_fix_scales(scales: Optional[torch.Tensor], num_experts: int) -> Option return scales +def maybe_fix_2d_scales(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + scales = scales.view(1, 1) + else: + scales = scales.view(-1, scales.size(-1)) + return scales + + def batched_moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -970,7 +991,6 @@ def apply( intermediate_cache1.fill_(0) a1q_scale = maybe_fix_scales(a1q_scale, E) - #a2_scale = maybe_fix_scales(a2_scale, E) # MM1 invoke_moe_batched_triton_kernel( @@ -1007,9 +1027,23 @@ def apply( intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N)) - qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, max_num_tokens, E, N, expert_num_tokens, - self.quant_dtype, self.per_act_token_quant, self.block_shape) + if True: + qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( + intermediate_cache2, a2_scale, max_num_tokens, E, N, expert_num_tokens, + self.quant_dtype, self.per_act_token_quant, self.block_shape) + else: + ic2_hidden_size = intermediate_cache2.size(-1) + intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size) + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + A=intermediate_cache2, + A_scale=a2_scale, + quant_dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape) + + qintermediate_cache2 = qintermediate_cache2.view( + (E, -1, ic2_hidden_size)) invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 75712b8e3a4d..8dd26f6c2d0d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -907,7 +907,7 @@ def fused_topk( # This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +#@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8501abd9e609..a702adeda29e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,6 +13,7 @@ import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, + get_world_group, get_tensor_model_parallel_world_size, get_world_group, tensor_model_parallel_all_reduce) @@ -650,8 +651,8 @@ def __init__( tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) - dp_size_ = (dp_size - if dp_size is not None else get_dp_group().world_size) + dp_size_ = (dp_size if dp_size is not None else + get_dp_group().world_size) world_size_ = get_world_group().world_size vllm_config = get_current_vllm_config() diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2ffb4d328eca..1dcac9aa3a45 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -454,6 +454,10 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts + #def maybe_shape(x): + # return x.shape if x is not None else None + #print(f"PROBLEM topk={topk_ids.size(1)} E={global_num_experts}/{local_num_experts}, M={a1.shape[0]}, N={w1.size(1)}, K={a1.shape[1]}, {w1_scale.shape}, {w2_scale.shape}, {maybe_shape(a1_scale)}, {maybe_shape(a2_scale)} pt={self.fused_experts.per_act_token_quant} bs={self.fused_experts.block_shape}") + (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index a93afc82a00f..69d84e9fdbe7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -171,7 +171,7 @@ def prepare( final_dim = expert_x.size(2) assert final_dim % 4 == 0 #? elif quant_config.is_per_tensor: - token_dim = 1 + token_dim = expert_x.size(1) #XXXXXXXXXXXXXXXXXX final_dim = 4 else: num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) @@ -184,6 +184,8 @@ def prepare( final_dim, ) + # XXXX make sure shape matches up with pplx hidden bytes + expert_x_scale = torch.empty( expert_x_scale_shape, dtype=torch.float32, @@ -194,6 +196,8 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None + #print(f"DISPATCH START") + self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -204,6 +208,8 @@ def prepare( bound_m=bound_m, ) + #print(f"DISPATCH END") + if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 @@ -233,8 +239,12 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) + #print(f"COMBINE START") + self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) + + #print(f"COMBINE END") diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0295f5e2a1c8..612eb99a124f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -905,6 +905,8 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) else: + #print(f"A1_SCALE = {layer.w13_input_scale}") + #print(f"A2_SCALE = {layer.w2_input_scale}") return self.fused_experts( hidden_states=x, w1=layer.w13_weight, From 946a950ff9d97d84bad72fcab3faebfcd1cffdff Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 26 Jun 2025 12:22:32 +0000 Subject: [PATCH 48/77] cleanup Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 22 ++++++++----------- .../layers/fused_moe/fused_batched_moe.py | 1 - .../layers/fused_moe/fused_moe.py | 2 +- .../layers/fused_moe/modular_kernel.py | 4 ---- .../layers/fused_moe/pplx_prepare_finalize.py | 11 +++------- 5 files changed, 13 insertions(+), 27 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 98268588a2be..615847b066bf 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -20,7 +20,7 @@ from tests.kernels.moe.utils import make_test_weights, naive_batched_moe from tests.kernels.utils import torch_experts -from tests.kernels.quant_utils import batched_dequant +from tests.kernels.quant_utils import dequant from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig @@ -42,7 +42,7 @@ ) PPLX_COMBOS = [ - # TODO: figure out why this fails + # TODO: figure out why this fails, seems to be test problem #(1, 128, 128), (2, 128, 512), @@ -53,12 +53,6 @@ (64, 1024, 512), (222, 2048, 1024), (256, 1408, 2048), - - #(6, 1408, 2048), - #(16, 1408, 2048), - #(199, 1408, 2048), - #(200, 1408, 2048), - #(256, 1408, 2048), ] NUM_EXPERTS = [8, 64] @@ -268,12 +262,16 @@ def pplx_prepare_finalize( dp_size, ) + assert a.shape[0] == topk_ids.shape[0] + a_chunk = chunk_by_rank(a, rank, world_size).to(device) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + assert a_chunk.shape[0] == chunk_topk_ids.shape[0] + out = torch.full( - (max_num_tokens, hidden_dim), + a_chunk.shape, torch.nan, dtype=a.dtype, device=device, @@ -306,7 +304,7 @@ def pplx_prepare_finalize( #print(f"B_A_SCALE = {b_a.shape}, {b_a_scale.shape if b_a_scale is not None else None}, {per_act_token_quant} {block_shape}, {a_chunk.shape}") # TOOD: shouldn't need batched_dequant - b_a = dummy_work(batched_dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) + b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) prepare_finalize.finalize( out, @@ -368,8 +366,6 @@ def _pplx_prepare_finalize( nvshmem_finalize() -# TODO (bnell): this test point does not work for M==1 due to how the test -# is written, not due to limitations of the pplx kernels. @pytest.mark.parametrize("mnk", PPLX_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @@ -444,7 +440,7 @@ def pplx_moe( hidden_dim = a.shape[1] num_experts = w1.shape[0] topk = topk_ids.shape[1] - max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64) + max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16) hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( max_num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index efe37e76b5c5..dd4156233e45 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -381,7 +381,6 @@ def invoke_moe_batched_triton_kernel( grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) - # ????? A_scale = maybe_fix_scales(A_scale, expert_num_tokens.shape[0]) if B_scale is not None and B_scale.ndim == 1: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8dd26f6c2d0d..75712b8e3a4d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -907,7 +907,7 @@ def fused_topk( # This is used by the Deepseek-V2 and Deepseek-V3 model -#@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 1dcac9aa3a45..2ffb4d328eca 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -454,10 +454,6 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - #def maybe_shape(x): - # return x.shape if x is not None else None - #print(f"PROBLEM topk={topk_ids.size(1)} E={global_num_experts}/{local_num_experts}, M={a1.shape[0]}, N={w1.size(1)}, K={a1.shape[1]}, {w1_scale.shape}, {w2_scale.shape}, {maybe_shape(a1_scale)}, {maybe_shape(a2_scale)} pt={self.fused_experts.per_act_token_quant} bs={self.fused_experts.block_shape}") - (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 69d84e9fdbe7..e7739a01cd62 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -196,8 +196,6 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - #print(f"DISPATCH START") - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -208,8 +206,6 @@ def prepare( bound_m=bound_m, ) - #print(f"DISPATCH END") - if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 @@ -231,6 +227,9 @@ def finalize( #assert topk_ids.size(0) == num_tokens, ( # f"{topk_ids.size(0)} == {num_tokens}") + assert topk_ids.size() == topk_weights.size(), ( + f"{topk_ids.size()} == {topk_weights.size()}" + ) assert output.size(0) <= self.max_num_tokens, ( f"{output.size(0)} <= {self.max_num_tokens}") assert output.size(1) == fused_expert_output.size(-1) @@ -239,12 +238,8 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - #print(f"COMBINE START") - self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) - - #print(f"COMBINE END") From d8a27230e82e00fad6f932671fb4e37cd0bd894c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 26 Jun 2025 21:38:32 +0000 Subject: [PATCH 49/77] lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a702adeda29e..8501abd9e609 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,7 +13,6 @@ import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, - get_world_group, get_tensor_model_parallel_world_size, get_world_group, tensor_model_parallel_all_reduce) @@ -651,8 +650,8 @@ def __init__( tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) - dp_size_ = (dp_size if dp_size is not None else - get_dp_group().world_size) + dp_size_ = (dp_size + if dp_size is not None else get_dp_group().world_size) world_size_ = get_world_group().world_size vllm_config = get_current_vllm_config() From e40e9c089f6eda740e271383e9693393b2b31555 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 27 Jun 2025 02:14:45 +0000 Subject: [PATCH 50/77] fix test Signed-off-by: Bill Nell --- .../fused_moe/deepep_ht_prepare_finalize.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index d8ddec9554f0..87c304ee1a95 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -136,20 +136,7 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) - # Check if there is a block_shape / or if we can infer the quantization - # schemes from the scales. - per_token_quant = None - if all([ - x is None - for x in [quant_config.block_shape, a1_scale, a2_scale] - ]) and quant_config.quant_dtype is not None: - # Quantization required despite none of the inputs suggesting - # quantization. Fallback to per_token_dynamic quant. - per_token_quant = True - else: - per_token_quant = False - - if per_token_quant: + if quant_config.per_act_token_quant: a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, From 79e8d6bba50e59c4dcdd7e13a658cf7b44d16af5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 27 Jun 2025 03:10:17 +0000 Subject: [PATCH 51/77] fixes Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/config.py | 10 ++++++++-- .../compressed_tensors/compressed_tensors_moe.py | 10 +++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index c5aca64b8f2b..fd41abadeb32 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -90,9 +90,14 @@ def is_grouped(self) -> bool: def is_per_tensor(self) -> bool: return not self.per_act_token_quant and self.block_shape is None - def scale_shape(self, max_tokens: int, hidden_dim: int) -> Optional[tuple[int, int]]: + def scale_shape( + self, + max_tokens: int, + hidden_dim: int, + ) -> Optional[tuple[int, int]]: if self.is_quantized: if self.is_grouped: + assert self.block_shape is not None _, block_k = self.block_shape k_tiles = cdiv(hidden_dim, block_k) return (max_tokens, k_tiles) @@ -107,10 +112,11 @@ def batched_scale_shape( self, num_experts: int, max_tokens: int, - hidden_dim: int + hidden_dim: int, ) -> Optional[tuple[int, int, int]]: if self.is_quantized: scale_shape = self.scale_shape(max_tokens, hidden_dim) + assert scale_shape is not None return (num_experts, *scale_shape) else: return None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 3288d2bddb75..42ee6443d654 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -572,13 +572,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: self.fused_experts_func = fused_experts - def select_gemm_impl(self, prepare_finalize): + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) assert not self.rocm_aiter_moe_enabled and not self.use_marlin - logger.debug("BatchedTritonExperts(%s)", self.__classname__.__name__) + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) use_batched_format = (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts) @@ -860,7 +864,7 @@ def select_gemm_impl( num_experts = (moe.num_local_experts if use_batched_format else moe.num_experts) - logger.debug("CutlassExpertsFp8(%s)", self.__classname__.__name__) + logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) experts = CutlassExpertsFp8( num_experts, From c02faca5189f201be725f10a7495d97ebdc5d3c9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 27 Jun 2025 04:29:28 +0000 Subject: [PATCH 52/77] fix pplx tests, fix indices type assert Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 133 ++++++------ tests/kernels/utils.py | 7 +- .../layers/fused_moe/fused_batched_moe.py | 196 ++++++++---------- vllm/model_executor/layers/fused_moe/layer.py | 6 +- 4 files changed, 156 insertions(+), 186 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 615847b066bf..cd45e0b2b50d 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -19,23 +19,21 @@ has_pplx = False from tests.kernels.moe.utils import make_test_weights, naive_batched_moe -from tests.kernels.utils import torch_experts from tests.kernels.quant_utils import dequant +from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) + BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) -from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform from vllm.utils import round_up from .parallel_utils import ProcessGroupInfo, parallel_launch - requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", @@ -44,7 +42,6 @@ PPLX_COMBOS = [ # TODO: figure out why this fails, seems to be test problem #(1, 128, 128), - (2, 128, 512), (3, 1024, 2048), (4, 128, 128), @@ -173,9 +170,11 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) # only for baseline + baseline_output = torch_experts(a, w1, w2, topk_weight, + topk_ids) # only for baseline torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) - batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this + batched_output = naive_batched_moe( + a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this torch.testing.assert_close(baseline_output, torch_output, @@ -277,7 +276,8 @@ def pplx_prepare_finalize( device=device, ) - if quant_dtype is not None and not per_act_token_quant and block_shape is None: + if (quant_dtype is not None and not per_act_token_quant + and block_shape is None): a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -301,10 +301,8 @@ def pplx_prepare_finalize( ), ) - #print(f"B_A_SCALE = {b_a.shape}, {b_a_scale.shape if b_a_scale is not None else None}, {per_act_token_quant} {block_shape}, {a_chunk.shape}") - # TOOD: shouldn't need batched_dequant - - b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) + b_a = dummy_work( + dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) prepare_finalize.finalize( out, @@ -358,7 +356,8 @@ def _pplx_prepare_finalize( num_experts, quant_dtype, block_shape, per_act_token_quant, group_name) - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pgi.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pgi.device) torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) @@ -408,9 +407,9 @@ def test_pplx_prepare_finalize( a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 score = torch.randn((m, e), device=device, dtype=act_dtype) - parallel_launch(world_size, _pplx_prepare_finalize, dp_size, - a, score, topk, e, quant_dtype, block_shape, - per_act_token_quant, use_internode) + parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, + topk, e, quant_dtype, block_shape, per_act_token_quant, + use_internode) def pplx_moe( @@ -479,11 +478,12 @@ def pplx_moe( dp_size, ) - experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + block_shape=block_shape) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -596,7 +596,8 @@ def _pplx_moe( w1_s = w1_s.to(device) if w1_s is not None else None w2_s = w2_s.to(device) if w2_s is not None else None - if quant_dtype is not None and not per_act_token_quant and block_shape is None: + if (quant_dtype is not None and not per_act_token_quant + and block_shape is None): a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -606,36 +607,14 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - if False: - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - topk_weight_chunk = chunk_by_rank(topk_weight, rank, world_size).to(device) - topk_ids_chunk = chunk_by_rank(topk_ids, rank, world_size).to(device) - w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) - w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) - - if w1_s is not None: - w1_s_chunk = chunk_by_rank(w1_s, rank, world_size).to(device) - w2_s_chunk = chunk_by_rank(w2_s, rank, world_size).to(device) - else: - w1_s_chunk = None - w2_s_chunk = None - else: - a_chunk = a - topk_weight_chunk = topk_weight - topk_ids_chunk = topk_ids - w1_chunk = w1 - w2_chunk = w2 - w1_s_chunk = w1_s - w2_s_chunk = w2_s - torch_output = torch_experts( - a_chunk, - w1_chunk, - w2_chunk, - topk_weight_chunk, - topk_ids_chunk, - w1_scale=w1_s_chunk, - w2_scale=w2_s_chunk, + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, a1_scale=a1_scale, a2_scale=a2_scale, quant_dtype=quant_dtype, @@ -644,25 +623,6 @@ def _pplx_moe( ) batched_output = naive_batched_moe( - a_chunk, - w1_chunk, - w2_chunk, - topk_weight_chunk, - topk_ids_chunk, - w1_scale=w1_s_chunk, - w2_scale=w2_s_chunk, - a1_scale=a1_scale, - a2_scale=a2_scale, - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) - - pplx_output = pplx_moe( - group_name, - rank, - world_size, - dp_size, a, w1, w2, @@ -674,16 +634,39 @@ def _pplx_moe( a2_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, - block_shape=block_shape) + block_shape=block_shape, + ) - # all reduce on pplx? - #torch.distributed.all_reduce(pplx_output) + pplx_output = pplx_moe(group_name, + rank, + world_size, + dp_size, + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape) chunked_torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - torch.testing.assert_close(pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2) - torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) + tol = 6e-2 if quant_dtype is not None else 3e-2 + + torch.testing.assert_close(pplx_output, + chunked_torch_output, + atol=tol, + rtol=tol) + torch.testing.assert_close(batched_output, + torch_output, + atol=3e-2, + rtol=3e-2) if use_internode: nvshmem_finalize() diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 84cf87d71d88..6b8ac92e60be 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1094,6 +1094,8 @@ def torch_experts( if expert_map is not None: topk_ids = expert_map[topk_ids] + f32 = torch.float32 + for i in range(num_experts): mask = topk_ids == i if mask.sum(): @@ -1117,7 +1119,6 @@ def torch_experts( else: assert (a_scale is not None and w1_scale is not None and w2_scale is not None) - f32 = torch.float32 scales = a_scale if a_scale.numel() == 1 else a_scale[mask] tmp1 = a[mask].to(f32) * scales w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) @@ -1126,8 +1127,8 @@ def torch_experts( w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M, -1, w2.shape[1]).to(f32) * + topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) def torch_moe(a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index dd4156233e45..7e3862986c0c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -13,7 +13,8 @@ get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + group_broadcast) @triton.jit @@ -72,7 +73,7 @@ def moe_mmk( elif per_act_token_quant: # Load per-token scale for activations a_scale_ptrs = a_scale_ptr + offs_m * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None] + a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) @@ -107,7 +108,8 @@ def moe_mmk( other=0.0) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] else: # acc used to enable fp8_fast_accum accumulator = tl.dot(a, b, acc=accumulator) @@ -228,49 +230,49 @@ def expert_triton_kernel( @triton.jit def batched_triton_kernel( - a_ptr, # [E, max_num_tokens, K] - b_ptr, # [E, K, N] - c_ptr, # [E, max_num_tokens, N] - expert_num_tokens, # [E] - compute_type: tl.constexpr, - # Dimensions - max_num_tokens, - K, - N, - # Quantization data - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ae, - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_ce, - stride_cm, - stride_cn, - stride_ase, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n: tl.constexpr, - group_k: tl.constexpr, - # Quantization schemes - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - per_act_token_quant: tl.constexpr, - # Kernel config - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ): expert_id = tl.program_id(axis=0) e_num_tokens = tl.load(expert_num_tokens + expert_id) @@ -306,9 +308,7 @@ def batched_triton_kernel( b_scale_ptr = b_scale_ptr + expert_id * stride_bse # block-wise - if group_k > 0 and group_n > 0: - a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm - elif per_act_token_quant: + if group_k > 0 and group_n > 0 or per_act_token_quant: a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm expert_triton_kernel( @@ -535,9 +535,8 @@ def prepare( device=a1.device) if quant_config.is_quantized: - scale_shape = quant_config.batched_scale_shape(num_local_experts, - self.max_num_tokens, - hidden_dim) + scale_shape = quant_config.batched_scale_shape( + num_local_experts, self.max_num_tokens, hidden_dim) b_a1_scale = torch.empty(scale_shape, dtype=torch.float32, @@ -685,7 +684,8 @@ def workspace_shapes( def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: assert self.quant_config.is_quantized f32 = torch.float32 - if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: + if (self.quant_config.is_per_act_token + or self.quant_config.is_per_tensor): return t.to(f32) * scale else: return t.to(f32) * group_broadcast(scale, t.shape) @@ -718,11 +718,10 @@ def apply( f"{num_local_experts} == {w1.size(0)}") N = w1.size(1) // 2 - f32 = torch.float32 for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor - if (True or torch.compiler.is_compiling() + if (torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): num = hidden_states.shape[1] else: @@ -734,11 +733,13 @@ def apply( tmp = _resize_cache(workspace2, (num, N)) if self.quant_config.is_quantized: - input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) + input = self.dequant(hidden_states[expert, :, :], + a1q_scale[expert]) w1_dq = self.dequant(w1[expert], w1_scale[expert]) input = input[:num] @ w1_dq.transpose(0, 1) else: - input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + input = hidden_states[expert, :num, :] @ w1[expert].transpose( + 0, 1) self.activation(activation, tmp, input.to(tmp.dtype)) @@ -750,22 +751,21 @@ def apply( output[expert, :num, :] = tmp @ w2_dq.transpose(0, 1).to(tmp.dtype) -def maybe_fix_scales(scales: Optional[torch.Tensor], num_experts: int) -> Optional[torch.Tensor]: +def maybe_fix_scales(scales: Optional[torch.Tensor], + num_experts: int) -> Optional[torch.Tensor]: if scales is not None and scales.ndim < 3: if scales.numel() == 1: scales = scales.view(1) - scales = torch.repeat_interleave( - scales, - num_experts, - dim=0 - ).view(num_experts, 1, 1) + scales = torch.repeat_interleave(scales, num_experts, + dim=0).view(num_experts, 1, 1) else: scales = scales.view(num_experts, -1, scales.size(-1)) return scales -def maybe_fix_2d_scales(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +def maybe_fix_2d_scales( + scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: scales = scales.view(1, 1) @@ -785,13 +785,14 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - # TODO: fix this + # TODO (bnell): fix this if (True or torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. hidden_dim = A.size(-1) - assert A_scale is None or A_scale.ndim <= 2, f"{A_scale.shape if A_scale is not None else None}" + assert A_scale is None or A_scale.ndim <= 2, ( + f"{A_scale.shape if A_scale is not None else None}") A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, hidden_dim), A_scale, qtype, per_act_token_quant, @@ -1010,8 +1011,7 @@ def apply( intermediate_cache2.fill_(0) - # TODO: would be nice to use expert_num_tokens here to reduce - # garbage compute + # TODO: use triton utility from batched deep gemm. if False: tmp = torch.empty_like(intermediate_cache2[0]) for e in range(E): @@ -1021,40 +1021,26 @@ def apply( intermediate_cache1[e, :num_tokens]) intermediate_cache2[e, :num_tokens] = tmp[:num_tokens] else: - self.activation( - activation, - intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) - - if True: - qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, max_num_tokens, E, N, expert_num_tokens, - self.quant_dtype, self.per_act_token_quant, self.block_shape) - else: - ic2_hidden_size = intermediate_cache2.size(-1) - intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size) - - qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - A=intermediate_cache2, - A_scale=a2_scale, - quant_dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) - - qintermediate_cache2 = qintermediate_cache2.view( - (E, -1, ic2_hidden_size)) - - invoke_moe_batched_triton_kernel(A=qintermediate_cache2, - B=w2, - C=output, - expert_num_tokens=expert_num_tokens, - compute_type=compute_type, - A_scale=a2q_scale, - B_scale=w2_scale, - B_zp=w2_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - config=config, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) + self.activation(activation, intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) + + qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( + intermediate_cache2, a2_scale, max_num_tokens, E, N, + expert_num_tokens, self.quant_dtype, self.per_act_token_quant, + self.block_shape) + + invoke_moe_batched_triton_kernel( + A=qintermediate_cache2, + B=w2, + C=output, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8501abd9e609..80583caad7f0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1186,9 +1186,9 @@ def select_experts( logical_replica_count: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Route the input hidden states to the top-k experts based on the + Route the input hidden states to the top-k experts based on the router logits. - + Returns: (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): The weights and *global physical* expert ids of the top-k experts. @@ -1299,7 +1299,7 @@ def select_experts( topk_ids = topk_ids.to(dtype=indices_type) - assert topk_ids.dtype == indices_type + assert topk_ids.dtype == indices_type or indices_type is None return topk_weights, topk_ids From 47eaa1979277b738b79d26fe41578e6865b54869 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 27 Jun 2025 17:03:56 +0000 Subject: [PATCH 53/77] fixes Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 - .../fused_moe/deepep_ll_prepare_finalize.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 89 +++++++------------ .../layers/fused_moe/fused_moe.py | 2 + .../layers/fused_moe/pplx_prepare_finalize.py | 30 +++---- vllm/model_executor/layers/fused_moe/utils.py | 18 +++- .../compressed_tensors_moe.py | 7 +- 7 files changed, 68 insertions(+), 84 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index cd45e0b2b50d..41f65d5661a6 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -675,8 +675,6 @@ def _pplx_moe( @pytest.mark.parametrize("mnk", PPLX_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -#@pytest.mark.parametrize("e", [32]) -#@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b315b4a97f04..f65f74a9444b 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -7,7 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - maybe_fix_scales, moe_kernel_quantize_input) + moe_kernel_quantize_input, normalize_batched_scales_shape) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -104,7 +104,7 @@ def _do_quant( if quant_dtype is not None: assert x_scales is not None - x_scales = maybe_fix_scales(x_scales, num_experts) + x_scales = normalize_batched_scales_shape(x_scales, num_experts) return x, x_scales diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7e3862986c0c..7a59b1dc3d8c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -12,7 +12,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input) + _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, + normalize_scales_shape) from vllm.model_executor.layers.quantization.utils.quant_utils import ( group_broadcast) @@ -381,7 +382,8 @@ def invoke_moe_batched_triton_kernel( grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) - A_scale = maybe_fix_scales(A_scale, expert_num_tokens.shape[0]) + A_scale = normalize_batched_scales_shape(A_scale, + expert_num_tokens.shape[0]) if B_scale is not None and B_scale.ndim == 1: assert B_scale.numel() == expert_num_tokens.shape[0] @@ -548,8 +550,8 @@ def prepare( first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts - a1_scale = maybe_fix_2d_scales(a1_scale) - a2_scale = maybe_fix_2d_scales(a2_scale) + a1_scale = normalize_scales_shape(a1_scale) + a2_scale = normalize_scales_shape(a2_scale) for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() @@ -575,10 +577,8 @@ def prepare( quant_config.block_shape, ) if quant_config.is_per_act_token: - #print(f"B_S1 {b_s.shape}") b_a1_scale[idx, :rows] = b_s[:rows] else: - #print(f"B_S2 {b_s.shape}") b_a1_scale[idx, :b_s.shape[0]] = b_s else: b_a1[idx, :rows, :] = rhs @@ -751,29 +751,6 @@ def apply( output[expert, :num, :] = tmp @ w2_dq.transpose(0, 1).to(tmp.dtype) -def maybe_fix_scales(scales: Optional[torch.Tensor], - num_experts: int) -> Optional[torch.Tensor]: - if scales is not None and scales.ndim < 3: - if scales.numel() == 1: - scales = scales.view(1) - scales = torch.repeat_interleave(scales, num_experts, - dim=0).view(num_experts, 1, 1) - else: - scales = scales.view(num_experts, -1, scales.size(-1)) - - return scales - - -def maybe_fix_2d_scales( - scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - if scales is not None: - if scales.numel() == 1: - scales = scales.view(1, 1) - else: - scales = scales.view(-1, scales.size(-1)) - return scales - - def batched_moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -785,8 +762,7 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - # TODO (bnell): fix this - if (True or torch.compiler.is_compiling() + if (torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. @@ -798,21 +774,21 @@ def batched_moe_kernel_quantize_input( qtype, per_act_token_quant, block_shape) A_q = A_q.view(E, -1, hidden_dim) - A_q_scale = maybe_fix_scales(A_q_scale, E) + A_q_scale = normalize_batched_scales_shape(A_q_scale, E) return A_q, A_q_scale - - if qtype is not None: - assert block_shape is not None + elif qtype is None: + return A, normalize_batched_scales_shape(A_scale, E) + else: A_q = torch.empty_like(A, dtype=qtype) if per_act_token_quant: assert block_shape is None scale_shape = (E, num_tokens, 1) elif block_shape is not None: - block_n, block_k = block_shape - n_tiles = (A.shape[-1] + block_n - 1) // block_n - scale_shape = (E, num_tokens, n_tiles) + _, block_k = block_shape + k_tiles = (A.shape[-1] + block_k - 1) // block_k + scale_shape = (E, num_tokens, k_tiles) else: scale_shape = (E, 1, 1) @@ -820,19 +796,27 @@ def batched_moe_kernel_quantize_input( dtype=torch.float32, device=A.device) + num_experts = expert_num_tokens.numel() + + A_scale = normalize_batched_scales_shape(A_scale, num_experts) + for e in range(E): num_tokens = expert_num_tokens[e] if num_tokens > 0: - A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( + if A_scale is not None: + scales = A_scale[e, :min(num_tokens, A_scale.shape[1])] + else: + scales = None + A_q[e, :num_tokens], tmp_scale = moe_kernel_quantize_input( A[e, :num_tokens], - A_scale[e, :min(num_tokens, A_scale.shape[1])] - if A_scale is not None else None, qtype, - per_act_token_quant, block_shape) + scales, + qtype, + per_act_token_quant, + block_shape, + ) A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale - else: - return A, A_scale class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -990,7 +974,7 @@ def apply( if self.use_fp8_w8a8: intermediate_cache1.fill_(0) - a1q_scale = maybe_fix_scales(a1q_scale, E) + a1q_scale = normalize_batched_scales_shape(a1q_scale, E) # MM1 invoke_moe_batched_triton_kernel( @@ -1011,18 +995,9 @@ def apply( intermediate_cache2.fill_(0) - # TODO: use triton utility from batched deep gemm. - if False: - tmp = torch.empty_like(intermediate_cache2[0]) - for e in range(E): - num_tokens = expert_num_tokens[e] - if num_tokens > 0: - self.activation(activation, tmp[:num_tokens], - intermediate_cache1[e, :num_tokens]) - intermediate_cache2[e, :num_tokens] = tmp[:num_tokens] - else: - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + # TODO (bnell): use triton utility from batched deep gemm. + self.activation(activation, intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( intermediate_cache2, a2_scale, max_num_tokens, E, N, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 75712b8e3a4d..041819bb7b08 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1127,6 +1127,8 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: return torch_vllm_outplace_fused_experts +# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace +# torch ops. def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index e7739a01cd62..dc8c2f66285a 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -8,7 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input, _validate_scale_shape) + _validate_scale_shape, moe_kernel_quantize_input) from vllm.utils import cdiv, round_up @@ -120,12 +120,8 @@ def prepare( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) - _validate_scale_shape( - a1q, - a1q_scale, - quant_config.per_act_token_quant, - quant_config.block_shape - ) + _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, + quant_config.block_shape) if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -169,13 +165,15 @@ def prepare( if quant_config.is_per_act_token: token_dim = expert_x.size(1) final_dim = expert_x.size(2) - assert final_dim % 4 == 0 #? + assert final_dim % float32_size == 0 elif quant_config.is_per_tensor: - token_dim = expert_x.size(1) #XXXXXXXXXXXXXXXXXX - final_dim = 4 + token_dim = expert_x.size(1) + final_dim = float32_size else: - num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) - final_dim = round_up(num_blocks, 4) + assert quant_config.block_shape is not None + num_blocks = cdiv(expert_x.size(2), + quant_config.block_shape[1]) + final_dim = round_up(num_blocks, float32_size) token_dim = expert_x.size(1) expert_x_scale_shape = ( @@ -184,7 +182,7 @@ def prepare( final_dim, ) - # XXXX make sure shape matches up with pplx hidden bytes + # TODO (bnell): make sure shape matches up with pplx hidden bytes expert_x_scale = torch.empty( expert_x_scale_shape, @@ -220,16 +218,16 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: - num_tokens = output.size(0) # M # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None + # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on + #num_tokens = output.size(0) # M #assert topk_ids.size(0) == num_tokens, ( # f"{topk_ids.size(0)} == {num_tokens}") assert topk_ids.size() == topk_weights.size(), ( - f"{topk_ids.size()} == {topk_weights.size()}" - ) + f"{topk_ids.size()} == {topk_weights.size()}") assert output.size(0) <= self.max_num_tokens, ( f"{output.size(0)} <= {self.max_num_tokens}") assert output.size(1) == fused_expert_output.size(-1) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 37f9581cd2e8..a90cce719b48 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -99,9 +99,20 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: return m[idx, ...] -# TODO(bnell): better name -def maybe_fix_scales(scales: Optional[torch.Tensor], - num_experts: int) -> Optional[torch.Tensor]: +def normalize_scales_shape( + scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + scales = scales.view(1, 1) + else: + scales = scales.view(-1, scales.size(-1)) + return scales + + +def normalize_batched_scales_shape( + scales: Optional[torch.Tensor], + num_experts: int, +) -> Optional[torch.Tensor]: if scales is not None and scales.ndim < 3: if scales.numel() == 1: scales = scales.view(1) @@ -128,5 +139,6 @@ def _validate_scale_shape( assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, ( f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1") else: + assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 42ee6443d654..581991d3e620 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -584,12 +584,11 @@ def select_gemm_impl( logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - use_batched_format = (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts) - - assert use_batched_format + assert (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts) max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None return BatchedTritonExperts( max_num_tokens=max_num_tokens_per_rank, From 8f3ee3aeee897a13ce477f471322a58dc6aa6023 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 27 Jun 2025 17:22:13 +0000 Subject: [PATCH 54/77] fix lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 4 ++-- tests/kernels/quant_utils.py | 7 ++++--- .../layers/fused_moe/batched_deep_gemm_moe.py | 2 -- .../layers/fused_moe/batched_triton_or_deep_gemm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/config.py | 5 ++--- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 4 ++++ 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index e37050f841d0..9e885bdc1c88 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -9,9 +9,8 @@ import triton.language as tl from tests.kernels.moe.utils import (batched_moe, - naive_batched_moe, make_quantized_test_activations, - make_test_weights, triton_moe) + make_test_weights, naive_batched_moe) from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config @@ -207,6 +206,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) + # @pytest.mark.parametrize("m", [6, 16, 199, 200, 256]) # @pytest.mark.parametrize("n", [2816//2]) # @pytest.mark.parametrize("k", [2048]) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 2970a7c9af61..6f43d1111c98 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -278,8 +278,8 @@ def dequant( def batched_dequant( - t: torch.Tensor, scale: - Optional[torch.Tensor], + t: torch.Tensor, + scale: Optional[torch.Tensor], block_shape: Optional[list[int]], per_act_token_quant: bool, out_dtype: Optional[torch.dtype] = torch.float32, @@ -288,7 +288,8 @@ def batched_dequant( assert t.shape[0] == scale.shape[0] out = torch.empty_like(t, dtype=out_dtype) for e in range(t.shape[0]): - out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, out_dtype) + out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, + out_dtype) return out return t.to(out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index b11c3855481e..6b08f32dff18 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,8 +7,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import _resize_cache -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig) from vllm.triton_utils import tl, triton logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 062f204798ef..65bd4f49b57f 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -8,7 +8,7 @@ BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) + BatchedTritonExperts) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index fd41abadeb32..083285d559de 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -16,7 +16,6 @@ QuantizationConfig) from vllm.utils import cdiv - logger = init_logger(__name__) @@ -71,8 +70,8 @@ class FusedMoEQuantConfig: # add detailed quant info for input, intermediates, weights, etc? def __post_init__(self): - assert (not self.per_act_token_quant or - self.block_shape is None), "illegal quantization" + assert (not self.per_act_token_quant + or self.block_shape is None), "illegal quantization" @property def is_quantized(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7a59b1dc3d8c..9a96c61a8a4a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -576,6 +576,7 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, ) + assert b_s is not None if quant_config.is_per_act_token: b_a1_scale[idx, :rows] = b_s[:rows] else: @@ -733,6 +734,7 @@ def apply( tmp = _resize_cache(workspace2, (num, N)) if self.quant_config.is_quantized: + assert a1q_scale is not None and w1_scale is not None input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) w1_dq = self.dequant(w1[expert], w1_scale[expert]) @@ -744,6 +746,7 @@ def apply( self.activation(activation, tmp, input.to(tmp.dtype)) if self.quant_config.is_quantized: + assert w2_scale is not None w2_dq = self.dequant(w2[expert], w2_scale[expert]) else: w2_dq = w2[expert] @@ -814,6 +817,7 @@ def batched_moe_kernel_quantize_input( per_act_token_quant, block_shape, ) + assert tmp_scale is not None A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale From 1f15b7301f046ffd977ebe92853d8bcd0212ee8f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 28 Jun 2025 02:49:59 +0000 Subject: [PATCH 55/77] fix per_act_token in pplx Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_cutlass_moe.py | 63 +-- tests/kernels/moe/test_pplx_moe.py | 435 ++++++++++++------ .../layers/fused_moe/cutlass_moe.py | 3 + .../layers/fused_moe/fused_batched_moe.py | 5 +- .../layers/fused_moe/pplx_prepare_finalize.py | 16 +- 5 files changed, 336 insertions(+), 186 deletions(-) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 184c2dd2f904..58fac246658a 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -181,35 +181,40 @@ def _pplx_moe( per_out_ch: bool, use_internode: bool, ): - if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - else: - group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") - group_name = cpu_group.group_name - - with set_current_vllm_config(vllm_config): - torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights, - topk_ids) - pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, - w2_scale, topk_weights, topk_ids, - a1_scale, out_dtype, per_act_token, - per_out_ch, group_name) - - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) - - # Uncomment if more debugging is needed - # print("PPLX OUT:", pplx_output) - # print("TORCH OUT:", torch_output) - - torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) - - if use_internode: - nvshmem_finalize() + try: + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, + backend="gloo") + group_name = cpu_group.group_name + + with set_current_vllm_config(vllm_config): + torch_output = torch_experts(a_full, w1_full, w2_full, + topk_weights, topk_ids) + pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, + w2_scale, topk_weights, topk_ids, + a1_scale, out_dtype, per_act_token, + per_out_ch, group_name) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + # Uncomment if more debugging is needed + # print("PPLX OUT:", pplx_output) + # print("TORCH OUT:", torch_output) + + torch.testing.assert_close(pplx_output, + torch_output, + atol=0.05, + rtol=0) + finally: + if use_internode: + nvshmem_finalize() @pytest.mark.parametrize("m", [2, 224]) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 41f65d5661a6..776c15a45805 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,6 +4,9 @@ Run `pytest tests/kernels/test_pplx_moe.py`. """ +import itertools +import textwrap +import traceback from typing import Optional import pytest @@ -25,7 +28,7 @@ from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + BatchedTritonExperts, NaiveBatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -53,8 +56,8 @@ ] NUM_EXPERTS = [8, 64] -EP_SIZE = [1, 4] TOP_KS = [1, 2, 6] +DTYPES = [torch.float8_e4m3fn, torch.bfloat16] vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -152,7 +155,7 @@ def torch_batched_moe( @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", DTYPES) def test_fused_moe_batched_experts( m: int, n: int, @@ -191,9 +194,30 @@ def rank_chunk(num: int, r: int, w: int) -> int: return (num // w) + (1 if r < rem else 0) -def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: - chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] +def chunk_by_rank(t: Optional[torch.Tensor], r: int, + w: int) -> Optional[torch.Tensor]: + if t is not None: + chunk = rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + else: + return t + + +def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int, + w: int) -> Optional[torch.Tensor]: + if t is not None and t.numel() > 1: + chunk = rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + else: + return t + + +def chunk_scales(t: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if t is not None and t.numel() > 1: + return t[start:end] + else: + return t def dummy_work(a: torch.Tensor) -> torch.Tensor: @@ -333,42 +357,49 @@ def _pplx_prepare_finalize( per_act_token_quant: bool, use_internode: bool, ): - if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - group_name = None - else: - group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") - group_name = cpu_group.group_name + try: + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + group_name = None + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, + backend="gloo") + group_name = cpu_group.group_name - topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - m, k = a.shape - - a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + m, k = a.shape - torch_output = (a_rep.view(m, topk, k) * - topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1) + a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) - pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, - num_experts, quant_dtype, block_shape, - per_act_token_quant, group_name) + torch_output = (a_rep.view(m, topk, k) * + topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum( + dim=1) - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pgi.device) + pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, + topk_ids, num_experts, quant_dtype, + block_shape, per_act_token_quant, + group_name) - torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pgi.device) - if use_internode: - nvshmem_finalize() + torch.testing.assert_close(pplx_output, + torch_output, + atol=3e-2, + rtol=3e-2) + finally: + if use_internode: + nvshmem_finalize() @pytest.mark.parametrize("mnk", PPLX_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @@ -435,7 +466,6 @@ def pplx_moe( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) - device = torch.device("cuda", rank) hidden_dim = a.shape[1] num_experts = w1.shape[0] topk = topk_ids.shape[1] @@ -472,18 +502,26 @@ def pplx_moe( prepare_finalize = PplxPrepareAndFinalize( ata, - max_num_tokens, - world_size, - rank, - dp_size, - ) - - experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, world_size=world_size, + rank=rank, dp_size=dp_size, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + ) + + if False: + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + block_shape=block_shape) + else: + experts = NaiveBatchedExperts( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + block_shape=block_shape) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -491,20 +529,17 @@ def pplx_moe( ) # Note: workers with the same dp_rank must use the exact same inputs. - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) - chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + a_chunk = chunk_by_rank(a, rank, world_size) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size) # Chunking weights like this only works for batched format - w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) - w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) - - if w1_scale is not None: - w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) - w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) - else: - w1_scale_chunk = None - w2_scale_chunk = None + w1_chunk = chunk_by_rank(w1, rank, world_size) + w2_chunk = chunk_by_rank(w2, rank, world_size) + w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size) + w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size) + a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size) + a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and @@ -526,8 +561,8 @@ def pplx_moe( chunk_topk_ids, w1_scale=w1_scale_chunk, w2_scale=w2_scale_chunk, - a1_scale=a1_scale, - a2_scale=a2_scale, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, global_num_experts=num_experts) if use_cudagraphs: @@ -542,8 +577,8 @@ def pplx_moe( chunk_topk_ids, w1_scale=w1_scale_chunk, w2_scale=w2_scale_chunk, - a1_scale=a1_scale, - a2_scale=a2_scale, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, global_num_experts=num_experts) torch.cuda.synchronize() @@ -571,117 +606,207 @@ def _pplx_moe( block_shape: Optional[list[int]] = None, use_internode: bool = False, ): - if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - group_name = None + try: + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + group_name = None + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, + backend="gloo") + group_name = cpu_group.group_name + + m, k = a.shape + e, _, n = w2.shape + + moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) + + device = torch.device("cuda", pgi.rank) + rank = pgi.rank + world_size = pgi.world_size + + a = a.to(device) + w1 = w1.to(device) + w2 = w2.to(device) + w1_s = w1_s.to(device) if w1_s is not None else None + w2_s = w2_s.to(device) if w2_s is not None else None + + if (quant_dtype is not None and not per_act_token_quant + and block_shape is None): + a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + else: + a1_scale = None + a2_scale = None + + with set_current_vllm_config(vllm_config), override_config(moe_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + + torch_output = torch_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + batched_output = naive_batched_moe( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + pplx_output = pplx_moe( + group_name, + rank, + world_size, + dp_size, + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + chunked_torch_output = chunk_by_rank( + torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + + #tol = 6e-2 if quant_dtype is not None else 3e-2 + tol = 3e-2 + + try: + torch.testing.assert_close(batched_output, + torch_output, + atol=3e-2, + rtol=3e-2) + + torch.testing.assert_close(pplx_output, + chunked_torch_output, + atol=tol, + rtol=tol) + except Exception: + #torch.set_printoptions(profile="full") + #print(f"PPLX {pplx_output.shape}\n{pplx_output}") + #print(f"TORCH {chunked_torch_output.shape}\n" + # f"{chunked_torch_output}") + raise + finally: + if use_internode: + nvshmem_finalize() + + +def format_result(msg, ex=None): + if ex is not None: + x = str(ex) + newx = x.strip(" \n\t")[:16] + if len(newx) < len(x): + newx = newx + " ..." + + prefix = "E\t" + print(f"{textwrap.indent(traceback.format_exc(), prefix)}") + print(f"FAILED {msg} - {newx}\n") else: - group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") - group_name = cpu_group.group_name + print(f"PASSED {msg}") - m, k = a.shape - e, _, n = w2.shape - moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) - - device = torch.device("cuda", pgi.rank) - rank = pgi.rank - world_size = pgi.world_size - a = a.to(device) - w1 = w1.to(device) - w2 = w2.to(device) - w1_s = w1_s.to(device) if w1_s is not None else None - w2_s = w2_s.to(device) if w2_s is not None else None - - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): - a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) - a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) - else: - a1_scale = None - a2_scale = None - - with set_current_vllm_config(vllm_config), override_config(moe_config): - topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - - torch_output = torch_experts( - a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - a2_scale=a2_scale, +def _pplx_moe_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool): + current_platform.seed_everything(7) + combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, + [False, True], [None, [128, 128]]) + exceptions = [] + count = 0 + for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos: + count = count + 1 + m, n, k = mnk + + if dtype == torch.float8_e4m3fn: + use_fp8_w8a8 = True + quant_dtype = dtype + else: + use_fp8_w8a8 = False + quant_dtype = None + + test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " + f"dtype={dtype}, per_act_token={per_act_token_quant}, " + f"block_shape={block_shape}") + + if not use_fp8_w8a8 and (per_act_token_quant + or block_shape is not None): + #pytest.skip("Skip quantization test for non-quantized type") + print( + f"{test_desc} - Skip quantization test for non-quantized type." + ) + continue + + if per_act_token_quant and block_shape is not None: + #pytest.skip("Skip illegal quantization combination.") + print(f"{test_desc} - Skip illegal quantization combination.") + continue + + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + + _, w1, w1_s, _, w2, w2_s = make_test_weights( + e, + n, + k, quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, block_shape=block_shape, - ) - - batched_output = naive_batched_moe( - a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - a2_scale=a2_scale, - quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, - block_shape=block_shape, ) - pplx_output = pplx_moe(group_name, - rank, - world_size, - dp_size, - a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - a2_scale=a2_scale, - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape) - - chunked_torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) - - tol = 6e-2 if quant_dtype is not None else 3e-2 - - torch.testing.assert_close(pplx_output, - chunked_torch_output, - atol=tol, - rtol=tol) - torch.testing.assert_close(batched_output, - torch_output, - atol=3e-2, - rtol=3e-2) + try: + _pplx_moe(pgi, dp_size, a, w1, w2, score, topk, w1_s, w2_s, + quant_dtype, per_act_token_quant, block_shape, + use_internode) + format_result(test_desc) + except Exception as ex: + format_result(test_desc, ex) + exceptions.append(ex) - if use_internode: - nvshmem_finalize() + if len(exceptions) > 0: + raise RuntimeError( + f"{len(exceptions)} of {count} tests failed in child process, " + f"rank={pgi.rank}.") @pytest.mark.parametrize("mnk", PPLX_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.skip(reason="Too slow, run manually for debugging.") @requires_pplx -def test_pplx_moe( +def test_pplx_moe_slow( mnk: tuple[int, int, int], e: int, topk: int, @@ -723,3 +848,15 @@ def test_pplx_moe( parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, use_internode) + + +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("use_internode", [False]) +@requires_pplx +def test_pplx_moe( + world_dp_size: tuple[int, int], + use_internode: bool, +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + parallel_launch(world_size, _pplx_moe_loop, dp_size, use_internode) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0ef4e4f767e3..002b7aefd3d6 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -178,6 +178,9 @@ def run_cutlass_moe_fp8( c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) + # Should this be filled always? + c1.fill_(0) + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, problem_sizes1, ab_strides1, ab_strides1, c_strides1, per_act_token, per_out_ch) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9a96c61a8a4a..7a45da3d8ed8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -676,11 +676,12 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - num_dp = self.dp_size + num_dp = self.world_size // self.dp_size num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) - return (workspace13, workspace2, workspace13, a.dtype) + output = workspace13 + return (workspace13, workspace2, output, a.dtype) def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: assert self.quant_config.is_quantized diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index dc8c2f66285a..78a02fc9596a 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -32,16 +32,16 @@ def pplx_hidden_dim_scale_bytes( elem_size = torch.float32.itemsize if per_act_token_quant: - # per-token + # per-token (M x 1) assert block_shape is None hidden_scale_bytes = elem_size elif block_shape is not None: - # per-group + # per-group (M x K_tiles) block_size = block_shape[1] num_blocks = cdiv(hidden_dim, block_size) hidden_scale_bytes = num_blocks * elem_size else: - # per-tensor + # per-tensor (1 x 1) hidden_scale_bytes = elem_size else: hidden_dim_bytes = hidden_dim * in_dtype.itemsize @@ -134,6 +134,7 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] if not quant_config.is_grouped: + # TODO (bnell): use group_broadcast instead? a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) assert a1q_scale is None or a1q_scale.ndim == 2, \ @@ -163,23 +164,26 @@ def prepare( float32_size = torch.float32.itemsize if quant_config.is_per_act_token: + # (M x 1) -> (E x M x 1) token_dim = expert_x.size(1) final_dim = expert_x.size(2) assert final_dim % float32_size == 0 elif quant_config.is_per_tensor: + # (1 x 1) -> (E x 1 x 1) token_dim = expert_x.size(1) - final_dim = float32_size + final_dim = 1 else: + # (M x K_tiles) -> (E x M x K_tiles) assert quant_config.block_shape is not None num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) - final_dim = round_up(num_blocks, float32_size) token_dim = expert_x.size(1) + final_dim = num_blocks expert_x_scale_shape = ( num_local_experts, token_dim, - final_dim, + round_up(final_dim, 4) # or 16? ) # TODO (bnell): make sure shape matches up with pplx hidden bytes From 7d891cd0571814112d0689f54ee2fdb5c8a1cf34 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 28 Jun 2025 02:59:33 +0000 Subject: [PATCH 56/77] cleanups Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 28 ++++++------------- .../layers/fused_moe/pplx_prepare_finalize.py | 12 ++------ 2 files changed, 11 insertions(+), 29 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 776c15a45805..bb05174373fe 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -697,25 +697,15 @@ def _pplx_moe( chunked_torch_output = chunk_by_rank( torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - #tol = 6e-2 if quant_dtype is not None else 3e-2 - tol = 3e-2 + torch.testing.assert_close(batched_output, + torch_output, + atol=3e-2, + rtol=3e-2) - try: - torch.testing.assert_close(batched_output, - torch_output, - atol=3e-2, - rtol=3e-2) - - torch.testing.assert_close(pplx_output, - chunked_torch_output, - atol=tol, - rtol=tol) - except Exception: - #torch.set_printoptions(profile="full") - #print(f"PPLX {pplx_output.shape}\n{pplx_output}") - #print(f"TORCH {chunked_torch_output.shape}\n" - # f"{chunked_torch_output}") - raise + torch.testing.assert_close(pplx_output, + chunked_torch_output, + atol=3e-2, + rtol=3e-2) finally: if use_internode: nvshmem_finalize() @@ -758,14 +748,12 @@ def _pplx_moe_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool): if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): - #pytest.skip("Skip quantization test for non-quantized type") print( f"{test_desc} - Skip quantization test for non-quantized type." ) continue if per_act_token_quant and block_shape is not None: - #pytest.skip("Skip illegal quantization combination.") print(f"{test_desc} - Skip illegal quantization combination.") continue diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 78a02fc9596a..b31ed09ad612 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -161,29 +161,23 @@ def prepare( expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: - float32_size = torch.float32.itemsize - if quant_config.is_per_act_token: - # (M x 1) -> (E x M x 1) - token_dim = expert_x.size(1) + # (M x 1) -> (E x M x K) final_dim = expert_x.size(2) - assert final_dim % float32_size == 0 elif quant_config.is_per_tensor: # (1 x 1) -> (E x 1 x 1) - token_dim = expert_x.size(1) final_dim = 1 else: # (M x K_tiles) -> (E x M x K_tiles) assert quant_config.block_shape is not None num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) - token_dim = expert_x.size(1) final_dim = num_blocks expert_x_scale_shape = ( num_local_experts, - token_dim, - round_up(final_dim, 4) # or 16? + expert_x.size(1), + round_up(final_dim, 4) # round up for alignment ) # TODO (bnell): make sure shape matches up with pplx hidden bytes From ca748edb6a8eca0d7c5ca9dfb4445fd7f7a615d0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 28 Jun 2025 04:03:11 +0000 Subject: [PATCH 57/77] use proper experts for test_pplx_moe, naive experts work Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 13 ++------ tests/kernels/moe/test_pplx_moe.py | 47 ++++++++++++++------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 9e885bdc1c88..e4e03bcae6a6 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -207,16 +207,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) -# @pytest.mark.parametrize("m", [6, 16, 199, 200, 256]) -# @pytest.mark.parametrize("n", [2816//2]) -# @pytest.mark.parametrize("k", [2048]) -# @pytest.mark.parametrize("e", [32]) -# @pytest.mark.parametrize("topk", [6]) -# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -# @pytest.mark.parametrize("per_act_token_quant", [False]) -# @pytest.mark.parametrize("block_shape", [None]) - -@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) +@pytest.mark.parametrize("m", [1, 32, 45, 64, 222]) +@pytest.mark.parametrize("n", [128, 512, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index bb05174373fe..38a884e0b36d 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) + BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -59,6 +59,13 @@ TOP_KS = [1, 2, 6] DTYPES = [torch.float8_e4m3fn, torch.bfloat16] +# some of these are failing. +PPLX_COMBOS = [ + (3, 1024, 2048), + (45, 512, 2048), + (222, 2048, 1024), +] + vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -155,7 +162,7 @@ def torch_batched_moe( @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -508,20 +515,12 @@ def pplx_moe( dp_size=dp_size, ) - if False: - experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) - else: - experts = NaiveBatchedExperts( - max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + block_shape=block_shape) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -694,18 +693,20 @@ def _pplx_moe( block_shape=block_shape, ) - chunked_torch_output = chunk_by_rank( - torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + chunked_batch_output = chunk_by_rank( + batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) + tol = 4e-2 if m < 256 else 6e-2 + torch.testing.assert_close(pplx_output, - chunked_torch_output, - atol=3e-2, - rtol=3e-2) + chunked_batch_output, + atol=tol, + rtol=tol) finally: if use_internode: nvshmem_finalize() @@ -727,8 +728,10 @@ def format_result(msg, ex=None): def _pplx_moe_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool): current_platform.seed_everything(7) + #combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, + # [False, True], [None, [128, 128]]) combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, - [False, True], [None, [128, 128]]) + [True], [None]) exceptions = [] count = 0 for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos: From 5eceb6d4c83c3cdbf2fd2d814c8e896837c2ddaa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 28 Jun 2025 20:44:16 +0000 Subject: [PATCH 58/77] fix test flag Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 17 +++++------------ tests/kernels/utils.py | 3 ++- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 38a884e0b36d..3785d61e5929 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -59,13 +59,6 @@ TOP_KS = [1, 2, 6] DTYPES = [torch.float8_e4m3fn, torch.bfloat16] -# some of these are failing. -PPLX_COMBOS = [ - (3, 1024, 2048), - (45, 512, 2048), - (222, 2048, 1024), -] - vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -520,7 +513,9 @@ def pplx_moe( world_size=world_size, dp_size=dp_size, use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + ) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -701,12 +696,10 @@ def _pplx_moe( atol=3e-2, rtol=3e-2) - tol = 4e-2 if m < 256 else 6e-2 - torch.testing.assert_close(pplx_output, chunked_batch_output, - atol=tol, - rtol=tol) + atol=3e-2, + rtol=3e-2) finally: if use_internode: nvshmem_finalize() diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 6b8ac92e60be..fcaa93762856 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1111,7 +1111,8 @@ def torch_experts( out.dtype) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, None, quant_dtype, per_act_token_quant, block_shape) + tmp2, a2_scale, quant_dtype, per_act_token_quant, + block_shape) out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, w2_scale[i], block_shape, From 9b92fee82c49a36458d7f85f08496e096a87cf12 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 28 Jun 2025 21:42:01 +0000 Subject: [PATCH 59/77] re-enable tests + loopify test_pplx_moe tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 178 +++++++++++++++++------------ 1 file changed, 106 insertions(+), 72 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 3785d61e5929..8591d1389f49 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -7,7 +7,7 @@ import itertools import textwrap import traceback -from typing import Optional +from typing import Callable, Optional import pytest import torch @@ -404,8 +404,9 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.skip(reason="Too slow, run manually for debugging.") @requires_pplx -def test_pplx_prepare_finalize( +def test_pplx_prepare_finalize_slow( mnk: tuple[int, int, int], e: int, topk: int, @@ -593,6 +594,7 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, + num_experts: int, w1_s: Optional[torch.Tensor] = None, w2_s: Optional[torch.Tensor] = None, quant_dtype: Optional[torch.dtype] = None, @@ -705,26 +707,79 @@ def _pplx_moe( nvshmem_finalize() -def format_result(msg, ex=None): - if ex is not None: - x = str(ex) - newx = x.strip(" \n\t")[:16] - if len(newx) < len(x): - newx = newx + " ..." +@pytest.mark.parametrize("mnk", PPLX_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) +@pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.skip(reason="Too slow, run manually for debugging.") +@requires_pplx +def test_pplx_moe_slow( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]], + use_internode: bool, +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size - prefix = "E\t" - print(f"{textwrap.indent(traceback.format_exc(), prefix)}") - print(f"FAILED {msg} - {newx}\n") + if dtype == torch.float8_e4m3fn: + use_fp8_w8a8 = True + quant_dtype = dtype else: - print(f"PASSED {msg}") + use_fp8_w8a8 = False + quant_dtype = None + + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): + pytest.skip("Skip quantization test for non-quantized type") + + if per_act_token_quant and block_shape is not None: + pytest.skip("Skip illegal quantization combination") + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + + _, w1, w1_s, _, w2, w2_s = make_test_weights( + e, + n, + k, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + ) + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, + w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, + use_internode) + + +def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, + make_weights: bool, test_fn: Callable): + + def format_result(msg, ex=None): + if ex is not None: + x = str(ex) + newx = x.strip(" \n\t")[:16] + if len(newx) < len(x): + newx = newx + " ..." + + prefix = "E\t" + print(f"{textwrap.indent(traceback.format_exc(), prefix)}") + print(f"FAILED {msg} - {newx}\n") + else: + print(f"PASSED {msg}") -def _pplx_moe_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool): current_platform.seed_everything(7) - #combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, - # [False, True], [None, [128, 128]]) combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, - [True], [None]) + [False, True], [None, [128, 128]]) exceptions = [] count = 0 for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos: @@ -756,19 +811,35 @@ def _pplx_moe_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool): a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - _, w1, w1_s, _, w2, w2_s = make_test_weights( - e, - n, - k, - quant_dtype=quant_dtype, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - ) + args = dict() + if make_weights: + _, w1, w1_s, _, w2, w2_s = make_test_weights( + e, + n, + k, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + ) + args["w1"] = w1 + args["w2"] = w2 + args["w1_s"] = w1_s + args["w2_s"] = w2_s try: - _pplx_moe(pgi, dp_size, a, w1, w2, score, topk, w1_s, w2_s, - quant_dtype, per_act_token_quant, block_shape, - use_internode) + test_fn( + pgi=pgi, + dp_size=dp_size, + a=a, + score=score, + topk=topk, + num_experts=e, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + use_internode=use_internode, + **args, + ) format_result(test_desc) except Exception as ex: format_result(test_desc, ex) @@ -778,60 +849,22 @@ def _pplx_moe_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool): raise RuntimeError( f"{len(exceptions)} of {count} tests failed in child process, " f"rank={pgi.rank}.") + else: + print(f"{count} of {count} tests passed in child process, " + f"rank={pgi.rank}.") -@pytest.mark.parametrize("mnk", PPLX_COMBOS) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) -@pytest.mark.parametrize("per_act_token_quant", [False, True]) -@pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) -@pytest.mark.skip(reason="Too slow, run manually for debugging.") @requires_pplx -def test_pplx_moe_slow( - mnk: tuple[int, int, int], - e: int, - topk: int, - dtype: torch.dtype, +def test_pplx_prepare_finalize( world_dp_size: tuple[int, int], - per_act_token_quant: bool, - block_shape: Optional[list[int]], use_internode: bool, ): current_platform.seed_everything(7) - m, n, k = mnk world_size, dp_size = world_dp_size - - if dtype == torch.float8_e4m3fn: - use_fp8_w8a8 = True - quant_dtype = dtype - else: - use_fp8_w8a8 = False - quant_dtype = None - - if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): - pytest.skip("Skip quantization test for non-quantized type") - - if per_act_token_quant and block_shape is not None: - pytest.skip("Skip illegal quantization combination") - - a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 - score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - - _, w1, w1_s, _, w2, w2_s = make_test_weights( - e, - n, - k, - quant_dtype=quant_dtype, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - ) - - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, - w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, - use_internode) + parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, False, + _pplx_prepare_finalize) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @@ -843,4 +876,5 @@ def test_pplx_moe( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_moe_loop, dp_size, use_internode) + parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, + _pplx_moe) From 8894d0f8a43c5dc100a2e705cee78b924b15939f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 29 Jun 2025 02:28:55 +0000 Subject: [PATCH 60/77] add optional tag to slow tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 8591d1389f49..a073d9b9b139 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -404,7 +404,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) -@pytest.mark.skip(reason="Too slow, run manually for debugging.") +@pytest.mark.optional @requires_pplx def test_pplx_prepare_finalize_slow( mnk: tuple[int, int, int], @@ -715,7 +715,7 @@ def _pplx_moe( @pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) -@pytest.mark.skip(reason="Too slow, run manually for debugging.") +@pytest.mark.optional @requires_pplx def test_pplx_moe_slow( mnk: tuple[int, int, int], From d135b4178ff9068227579ce5ecb0f8ba49219043 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 29 Jun 2025 19:45:20 +0000 Subject: [PATCH 61/77] fix lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index a073d9b9b139..98fa657404a7 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -194,11 +194,15 @@ def rank_chunk(num: int, r: int, w: int) -> int: return (num // w) + (1 if r < rem else 0) -def chunk_by_rank(t: Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: + chunk = rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + + +def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int, + w: int) -> Optional[torch.Tensor]: if t is not None: - chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return chunk_by_rank(t, r, w) else: return t @@ -531,8 +535,8 @@ def pplx_moe( # Chunking weights like this only works for batched format w1_chunk = chunk_by_rank(w1, rank, world_size) w2_chunk = chunk_by_rank(w2, rank, world_size) - w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size) - w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size) + w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size) + w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size) a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size) a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size) From 35966a029420aa8868c6e2b807bb7105f531a30d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Jul 2025 01:51:25 +0000 Subject: [PATCH 62/77] tweaks Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7a45da3d8ed8..3c6f7ba96144 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -766,7 +766,7 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (torch.compiler.is_compiling() + if (True or torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. From 8485bdeb258d32a309b1c2c79119fc810c542dcf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Jul 2025 15:25:19 +0000 Subject: [PATCH 63/77] fixes Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 3c6f7ba96144..ea233f7082d3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -676,7 +676,7 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - num_dp = self.world_size // self.dp_size + num_dp = self.world_size num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) @@ -766,7 +766,7 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (True or torch.compiler.is_compiling() + if (torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. @@ -805,7 +805,7 @@ def batched_moe_kernel_quantize_input( A_scale = normalize_batched_scales_shape(A_scale, num_experts) for e in range(E): - num_tokens = expert_num_tokens[e] + num_tokens = int(expert_num_tokens[e].item()) if num_tokens > 0: if A_scale is not None: scales = A_scale[e, :min(num_tokens, A_scale.shape[1])] From 70fa1ddf2373ffe606b9df68a1e8a77f32c26520 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Jul 2025 19:31:22 +0000 Subject: [PATCH 64/77] fixup world_size/dp_size params Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_cutlass_moe.py | 13 +- tests/kernels/moe/test_pplx_moe.py | 158 ++++++++++-------- .../layers/fused_moe/batched_deep_gemm_moe.py | 10 +- .../batched_triton_or_deep_gemm_moe.py | 16 +- .../model_executor/layers/fused_moe/config.py | 14 +- .../layers/fused_moe/cutlass_moe.py | 9 +- .../fused_moe/deepep_ht_prepare_finalize.py | 6 +- .../fused_moe/deepep_ll_prepare_finalize.py | 4 - .../layers/fused_moe/fused_batched_moe.py | 34 ++-- vllm/model_executor/layers/fused_moe/layer.py | 27 +-- .../layers/fused_moe/pplx_prepare_finalize.py | 27 +-- .../compressed_tensors_moe.py | 4 +- .../model_executor/layers/quantization/fp8.py | 5 +- 13 files changed, 156 insertions(+), 171 deletions(-) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 58fac246658a..37c5ed008a0b 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform +from vllm.utils import cdiv from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -112,15 +113,17 @@ def pplx_cutlass_moe( w2_scale = w2_scale.to(device) a1_scale = a1_scale.to(device) + assert num_experts % world_size == 0 + num_local_experts = cdiv(num_experts, world_size) + prepare_finalize = PplxPrepareAndFinalize( ata, - max_num_tokens, - pgi.world_size, - rank, - dp_size, + max_num_tokens=max_num_tokens, + num_local_experts=num_local_experts, + num_dispatchers=pgi.world_size // dp_size, ) - experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, + experts = CutlassExpertsFp8(num_local_experts, out_dtype, per_act_token, per_out_ch, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 98fa657404a7..a2352550a12c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -189,6 +189,63 @@ def test_fused_moe_batched_experts( rtol=0) +def create_pplx_prepare_finalize( + num_tokens: int, + hidden_dim: int, + topk: int, + num_experts: int, + rank: int, + dp_size: int, + world_size: int, + in_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype], + block_shape: Optional[list[int]], + per_act_token_quant: bool, + group_name: Optional[str], +): + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) + + max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1) + num_local_experts = rank_chunk(num_experts, 0, world_size) + + hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( + max_num_tokens, + hidden_dim, + in_dtype, + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + args = dict( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=scale_bytes, + ) + + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = AllToAll.intranode(**args) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens=max_num_tokens, + num_local_experts=num_local_experts, + num_dispatchers=world_size // dp_size, + ) + + return prepare_finalize, ata + + def rank_chunk(num: int, r: int, w: int) -> int: rem = num % w return (num // w) + (1 if r < rem else 0) @@ -240,9 +297,6 @@ def pplx_prepare_finalize( per_act_token_quant: bool, group_name: Optional[str], ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) - assert torch.cuda.current_device() == pgi.local_rank topk = topk_ids.shape[1] @@ -250,43 +304,23 @@ def pplx_prepare_finalize( device = pgi.device rank = pgi.rank world_size = pgi.world_size - max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1) - - hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( - max_num_tokens, - hidden_dim, - a.dtype, - quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) - - args = dict( - max_num_tokens=max_num_tokens, - num_experts=num_experts, - experts_per_token=topk, - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim_bytes, - hidden_dim_scale_bytes=scale_bytes, - ) - - if group_name is None: - ata = AllToAll.internode(**args) - else: - args["group_name"] = group_name - ata = AllToAll.intranode(**args) + print(f"PGI {pgi} {world_size} {dp_size}") topk_ids = topk_ids.to(dtype=torch.uint32) - prepare_finalize = PplxPrepareAndFinalize( - ata, - max_num_tokens, - world_size, + prepare_finalize, ata = create_pplx_prepare_finalize( + num_tokens, + hidden_dim, + topk, + num_experts, rank, dp_size, + world_size, + a.dtype, + quant_dtype, + block_shape, + per_act_token_quant, + group_name, ) assert a.shape[0] == topk_ids.shape[0] @@ -468,51 +502,29 @@ def pplx_moe( use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) - hidden_dim = a.shape[1] + num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] topk = topk_ids.shape[1] max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16) - hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( - max_num_tokens, + prepare_finalize, ata = create_pplx_prepare_finalize( + num_tokens, hidden_dim, + topk, + num_experts, + rank, + dp_size, + world_size, a.dtype, quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) - - args = dict( - max_num_tokens=max_num_tokens, - num_experts=num_experts, - experts_per_token=topk, - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim_bytes, - hidden_dim_scale_bytes=scale_bytes, + block_shape, + per_act_token_quant, + group_name, ) - if group_name is None: - ata = AllToAll.internode(**args) - else: - args["group_name"] = group_name - ata = AllToAll.intranode(**args) - topk_ids = topk_ids.to(dtype=torch.uint32) - prepare_finalize = PplxPrepareAndFinalize( - ata, - max_num_tokens=max_num_tokens, - world_size=world_size, - rank=rank, - dp_size=dp_size, - ) - experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, world_size=world_size, @@ -858,7 +870,8 @@ def format_result(msg, ex=None): f"rank={pgi.rank}.") -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) +#@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_prepare_finalize( @@ -867,11 +880,12 @@ def test_pplx_prepare_finalize( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, False, - _pplx_prepare_finalize) + parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, + use_internode, False, _pplx_prepare_finalize) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) +#@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_moe( diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 6b08f32dff18..676629b38af3 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -184,14 +184,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, max_num_tokens: int, - world_size: int, - dp_size: int, + num_dispatchers: int, block_shape: list[int], per_act_token_quant=False): """ max_num_tokens: Maximum number of tokens from a DP Rank - world_size: Number of EP ranks - dp_size: Number of data-parallel ranks block_shape: Block quantization block shape """ super().__init__( @@ -202,8 +199,7 @@ def __init__(self, )) assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE self.max_num_tokens = max_num_tokens - self.world_size = world_size - self.dp_size = dp_size + self.num_dispatchers = num_dispatchers @property def activation_formats( @@ -233,7 +229,7 @@ def workspace_shapes( # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. At the moment, all the Ranks # end up sending their tokens. This needs to be fixed. - num_dispatchers = self.world_size + num_dispatchers = self.num_dispatchers num_experts = local_num_experts max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 65bd4f49b57f..7d5c04f2560c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -15,8 +15,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, max_num_tokens: int, - world_size: int, - dp_size: int, + num_dispatchers: int, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -37,15 +36,11 @@ def __init__(self, block_shape=block_shape, per_act_token_quant=per_act_token_quant, )) - self.max_num_tokens = max_num_tokens - self.world_size = world_size - self.dp_size = dp_size self.allow_deep_gemm = allow_deep_gemm self.batched_triton_experts = BatchedTritonExperts( - max_num_tokens=self.max_num_tokens, - world_size=self.world_size, - dp_size=self.dp_size, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, @@ -59,9 +54,8 @@ def __init__(self, == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) self.batched_deep_gemm_experts = BatchedDeepGemmExperts( - max_num_tokens=self.max_num_tokens, - world_size=self.world_size, - dp_size=self.dp_size, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, block_shape=self.block_shape, # type: ignore[arg-type] ) if self.allow_deep_gemm else None diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 083285d559de..e3b8e1cf2c3c 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -161,7 +161,7 @@ class FusedMoEParallelConfig: tp_rank: int dp_rank: int ep_rank: int - world_size: int + num_dispatchers: int use_ep: bool # whether to use EP or not @@ -185,7 +185,7 @@ def use_deepep_ll_kernels(self): and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") @staticmethod - def make(tp_size_: int, dp_size_: int, world_size_: int, + def make(tp_size_: int, dp_size_: int, num_dispatchers_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. Based on the input tp_size_, @@ -196,7 +196,7 @@ def make(tp_size_: int, dp_size_: int, world_size_: int, tp_size_ (int): tp_size passed into the FusedMoE constructor. dp_size_ (int): dp_size passed into the FusedMoE constructor. ep_size_ (int): ep_size passed into the FusedMoE constructor. - world_size_ (int): the world size of the current All2All manager. + num_dispatchers_ (int): the number of DP dispatchers. vllm_parallel_config (ParallelConfig): vllm's parallel config object. @@ -275,7 +275,7 @@ def flatten_tp_across_dp(dp_rank: int): dp_rank=dp_rank, ep_size=1, ep_rank=0, - world_size=world_size_, + num_dispatchers=num_dispatchers_, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -289,7 +289,7 @@ def flatten_tp_across_dp(dp_rank: int): dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, - world_size=world_size_, + num_dispatchers=num_dispatchers_, use_ep=True) @@ -358,8 +358,8 @@ def ep_size(self): return self.moe_parallel_config.ep_size @property - def world_size(self): - return self.moe_parallel_config.world_size + def num_dispatchers(self): + return self.moe_parallel_config.num_dispatchers @property def tp_rank(self): diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 002b7aefd3d6..bffac712dbfd 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -212,6 +212,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_experts_per_worker: int, + num_dispatchers: int, out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, @@ -227,6 +228,7 @@ def __init__( )) assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker + self.num_dispatchers = num_dispatchers self.out_dtype = out_dtype self.use_batched_format = use_batched_format @@ -263,8 +265,11 @@ def workspace_shapes( output: tuple[int, ...] = () if self.use_batched_format: padded_M = aq.size(1) - workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) + num_dp = self.num_dispatchers + workspace1 = (self.max_experts_per_worker, padded_M * num_dp, + max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, + (N // 2)) output = (self.max_experts_per_worker, padded_M, K) else: workspace1 = (M * topk, max(2 * N, K)) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 87c304ee1a95..bba89d3406d6 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -16,12 +16,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. """ - def __init__(self, buffer: deep_ep.Buffer, world_size: int, rank: int, - dp_size: int, rank_expert_offset: int): + def __init__(self, buffer: deep_ep.Buffer, rank: int, dp_size: int, + rank_expert_offset: int): super().__init__() self.buffer = buffer - self.world_size = world_size - self.rank = rank self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset # The dispatch function returns a handle that the combine function diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index f65f74a9444b..dec924d9d65b 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -42,15 +42,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__(self, buffer: deep_ep.Buffer, max_tokens_per_rank: int, - world_size: int, - dp_size: int, use_fp8_dispatch: bool = False): super().__init__() self.buffer = buffer self.max_tokens_per_rank = max_tokens_per_rank - self.world_size = world_size - self.dp_size = dp_size self.use_fp8_dispatch = use_fp8_dispatch # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index ea233f7082d3..38b4336d5c96 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -471,15 +471,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, max_num_tokens: int, - world_size: int, - dp_size: int, + num_local_experts: int, rank: int, ): super().__init__() - self.world_size = world_size - self.dp_size = dp_size - self.rank = rank self.max_num_tokens = max_num_tokens + self.num_local_experts = num_local_experts + self.rank = rank @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -522,9 +520,7 @@ def prepare( dtype=torch.int, device=a1.device) - assert num_experts % self.world_size == 0 - - num_local_experts = num_experts // self.world_size + num_local_experts = self.num_local_experts if quant_config.quant_dtype is None: b_type = a1.dtype @@ -626,8 +622,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_num_tokens: int, - world_size: int, - dp_size: int, + num_dispatchers: int, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -648,8 +643,7 @@ def __init__( assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" self.max_num_tokens = max_num_tokens - self.world_size = world_size - self.dp_size = dp_size + self.num_dispatchers = num_dispatchers @property def activation_formats( @@ -676,7 +670,7 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - num_dp = self.world_size + num_dp = self.num_dispatchers # global // local? num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) @@ -834,8 +828,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_num_tokens: int, - world_size: int, - dp_size: int, + num_dispatchers: int, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -855,17 +848,14 @@ def __init__( assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + assert max_num_tokens > 0 + assert num_dispatchers > 0 self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int8_w8a8 = use_int8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 self.max_num_tokens = max_num_tokens - self.world_size = world_size - self.dp_size = dp_size - assert world_size > 0 - assert dp_size > 0 - assert dp_size <= world_size - assert max_num_tokens > 0 + self.num_dispatchers = num_dispatchers @property def activation_formats( @@ -892,7 +882,7 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - num_dp = self.world_size + num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 80583caad7f0..74b3fcbdd3b0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -114,6 +114,9 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, hidden_dim_scale_bytes=hidden_scale_bytes, ) + assert (all2all_manager.world_size // + all2all_manager.tp_group.world_size) == moe.num_dispatchers + # Intranode pplx a2a takes a group name while internode does not. if not all2all_manager.internode: all_to_all_args[ @@ -124,10 +127,8 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize = PplxPrepareAndFinalize( handle, max_num_tokens=moe.max_num_tokens, - world_size=all2all_manager.world_size, - rank=all2all_manager.rank, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, + num_local_experts=moe.num_local_experts, + num_dispatchers=moe.num_dispatchers, ) elif moe.use_deepep_ht_kernels: assert moe.dp_size == all2all_manager.dp_world_size @@ -136,16 +137,12 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = DeepEPHTPrepareAndFinalize( handle, - world_size=all2all_manager.world_size, - rank=all2all_manager.rank, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * moe.num_local_experts, ) elif moe.use_deepep_ll_kernels: - assert moe.dp_size == all2all_manager.dp_world_size - all_to_all_args = dict( max_num_tokens_per_dp_rank=moe.max_num_tokens, token_hidden_size=moe.hidden_dim, @@ -168,8 +165,6 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize = DeepEPLLPrepareAndFinalize( handle, max_tokens_per_rank=moe.max_num_tokens, - world_size=all2all_manager.world_size, - dp_size=all2all_manager.dp_world_size, use_fp8_dispatch=use_fp8_dispatch, ) @@ -654,12 +649,14 @@ def __init__( if dp_size is not None else get_dp_group().world_size) world_size_ = get_world_group().world_size + num_dispatchers = world_size_ // tp_size_ + vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( tp_size_=tp_size_, dp_size_=dp_size_, - world_size_=world_size_, + num_dispatchers_=num_dispatchers, vllm_parallel_config=vllm_config.parallel_config)) self.global_num_experts = num_experts + num_redundant_experts @@ -1332,8 +1329,12 @@ def maybe_all_reduce_tensor_model_parallel( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + # TBD + if hidden_states.shape[0] < envs.VLLM_FUSED_MOE_CHUNK_SIZE: + return self.forward_impl(hidden_states, router_logits) + else: + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b31ed09ad612..aae3f6e1c8d8 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -53,25 +53,22 @@ def pplx_hidden_dim_scale_bytes( ) -# The max_num_tokens, world_size and dp_size must be the same -# as the ones used to create the AllToAll. class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, a2a: pplx.AllToAll, max_num_tokens: int, - world_size: int, - rank: int, - dp_size: int, + num_local_experts: int, + num_dispatchers: int, ): super().__init__() assert max_num_tokens > 0 + assert num_local_experts > 0 self.a2a = a2a self.max_num_tokens = max_num_tokens - self.world_size = world_size - self.rank = rank - self.dp_size = dp_size + self.num_local_experts = num_local_experts + self.num_dispatchers = num_dispatchers @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -140,21 +137,15 @@ def prepare( assert a1q_scale is None or a1q_scale.ndim == 2, \ f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" - # rem_experts need to be 0 for pplx to work properly. - rem_experts = num_experts % self.world_size - assert rem_experts == 0 - num_local_experts = ((num_experts // self.world_size) + - (1 if self.rank < rem_experts else 0)) - expert_num_tokens = torch.empty( - num_local_experts, + self.num_local_experts, dtype=torch.int32, device=device, ) - num_dp = self.world_size // self.dp_size expert_x = torch.empty( - (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), + (self.num_local_experts, + self.max_num_tokens * self.num_dispatchers, hidden_dim), dtype=a1q.dtype, device=device, ) @@ -175,7 +166,7 @@ def prepare( final_dim = num_blocks expert_x_scale_shape = ( - num_local_experts, + self.num_local_experts, expert_x.size(1), round_up(final_dim, 4) # round up for alignment ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 581991d3e620..b396fb354cdf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -592,8 +592,7 @@ def select_gemm_impl( return BatchedTritonExperts( max_num_tokens=max_num_tokens_per_rank, - world_size=moe.world_size, - dp_size=moe.dp_size, + num_dispatchers=moe.num_dispatchers, use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, per_act_token_quant=( @@ -867,6 +866,7 @@ def select_gemm_impl( experts = CutlassExpertsFp8( num_experts, + moe.num_dispatchers, moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 612eb99a124f..7cbf65db9b76 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -800,10 +800,7 @@ def select_gemm_impl( self.quant_config.weight_block_size, False) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, - world_size=prepare_finalize. - world_size, # type: ignore [attr-defined] - dp_size=prepare_finalize. - dp_size, # type: ignore [attr-defined] + num_dispatchers=moe.num_dispatchers, use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, per_act_token_quant=False, From 9c562063b185c6eb3b7a7534c1e6797bfd0f1430 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Jul 2025 20:15:10 +0000 Subject: [PATCH 65/77] fix tests Signed-off-by: Bill Nell --- tests/kernels/moe/parallel_utils.py | 4 ---- tests/kernels/moe/test_deepep_deepgemm_moe.py | 3 +-- tests/kernels/moe/test_pplx_cutlass_moe.py | 5 +++-- tests/kernels/moe/test_pplx_moe.py | 11 +++++------ tests/kernels/moe/utils.py | 12 ++++-------- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 4 +++- 6 files changed, 16 insertions(+), 23 deletions(-) diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index 7797e4f0c9c0..d2bf21c02bd5 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -137,7 +137,6 @@ def make_deepep_ht_a2a(pg: ProcessGroup, low_latency_mode=low_latency_mode, num_qps_per_rank=num_qps_per_rank) return DeepEPHTPrepareAndFinalize(buffer=buffer, - world_size=pgi.world_size, rank=pgi.rank, dp_size=dp_size, rank_expert_offset=pgi.rank * @@ -146,7 +145,6 @@ def make_deepep_ht_a2a(pg: ProcessGroup, def make_deepep_ll_a2a(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, deepep_ll_args: DeepEPLLArgs, q_dtype: Optional[torch.dtype] = None, block_shape: Optional[list[int]] = None): @@ -166,8 +164,6 @@ def make_deepep_ll_a2a(pg: ProcessGroup, return DeepEPLLPrepareAndFinalize( buffer=buffer, - world_size=pgi.world_size, - dp_size=dp_size, max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, ) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 9b861d4ebc23..23eb5fcc9453 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -148,8 +148,7 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, fused_experts = BatchedDeepGemmExperts( max_num_tokens=max_tokens_per_rank, - world_size=pgi.world_size, - dp_size=dp_size, + num_dispatchers=pgi.world_size // dp_size, block_shape=test_config.block_size, per_act_token_quant=test_config.per_act_token_quant) mk = FusedMoEModularKernel(prepare_finalize=a2a, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 37c5ed008a0b..e4f4a393dfd5 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -115,18 +115,19 @@ def pplx_cutlass_moe( assert num_experts % world_size == 0 num_local_experts = cdiv(num_experts, world_size) + num_dispatchers = pgi.world_size // dp_size prepare_finalize = PplxPrepareAndFinalize( ata, max_num_tokens=max_num_tokens, num_local_experts=num_local_experts, - num_dispatchers=pgi.world_size // dp_size, - ) + num_dispatchers=num_dispatchers) experts = CutlassExpertsFp8(num_local_experts, out_dtype, per_act_token, per_out_ch, + num_dispatchers=num_dispatchers, use_batched_format=True) fused_cutlass_experts = FusedMoEModularKernel( diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index a2352550a12c..a105c901d37f 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -527,8 +527,7 @@ def pplx_moe( experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, + num_dispatchers=prepare_finalize.num_dispatchers, use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, block_shape=block_shape, per_act_token_quant=per_act_token_quant, @@ -870,8 +869,8 @@ def format_result(msg, ex=None): f"rank={pgi.rank}.") -@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) -#@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +#@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_prepare_finalize( @@ -884,8 +883,8 @@ def test_pplx_prepare_finalize( use_internode, False, _pplx_prepare_finalize) -@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) -#@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +#@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_moe( diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 5b1048797447..b7b16b7f2e35 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -63,13 +63,11 @@ def batched_moe( fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, - world_size=1, - dp_size=1, + num_local_experts=w1.shape[0], rank=0), BatchedTritonExperts( max_num_tokens=max_num_tokens, - world_size=1, - dp_size=1, + num_dispatchers=1, use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, per_act_token_quant=per_act_token_quant, block_shape=block_shape, @@ -105,13 +103,11 @@ def naive_batched_moe( fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, - world_size=1, - dp_size=1, + num_local_experts=w1.shape[0], rank=0), NaiveBatchedExperts( max_num_tokens=max_num_tokens, - dp_size=1, - world_size=1, + num_dispatchers=1, use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, per_act_token_quant=per_act_token_quant, block_shape=block_shape, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index bffac712dbfd..41063706c361 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -212,11 +212,11 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_experts_per_worker: int, - num_dispatchers: int, out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, block_shape: Optional[list[int]] = None, + num_dispatchers: Optional[int] = None, use_batched_format: bool = False, ): super().__init__( @@ -227,6 +227,7 @@ def __init__( block_shape=block_shape, )) assert max_experts_per_worker > 0 + assert not use_batched_format or num_dispatchers is not None self.max_experts_per_worker = max_experts_per_worker self.num_dispatchers = num_dispatchers self.out_dtype = out_dtype @@ -266,6 +267,7 @@ def workspace_shapes( if self.use_batched_format: padded_M = aq.size(1) num_dp = self.num_dispatchers + assert num_dp is not None workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) workspace2 = (self.max_experts_per_worker, padded_M * num_dp, From 653942fa9a914bf471a5a2f133c56207f2f1ef84 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Jul 2025 20:50:49 +0000 Subject: [PATCH 66/77] more test fixes Signed-off-by: Bill Nell --- tests/kernels/moe/parallel_utils.py | 3 +-- tests/kernels/moe/test_batched_moe.py | 11 ++--------- tests/kernels/moe/test_deepep_moe.py | 5 +++-- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index d2bf21c02bd5..497234a9cd6e 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -182,5 +182,4 @@ def make_deepep_a2a(pg: ProcessGroup, block_shape) assert deepep_ll_args is not None - return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype, - block_shape) + return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index e4e03bcae6a6..3582ff6dc4ea 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -33,12 +33,10 @@ (45, 512, 512), (45, 1024, 128), (45, 1024, 2048), - (64, 128, 128), (64, 512, 512), (64, 1024, 2048), (222, 128, 128), (222, 128, 2048), - (222, 512, 512), (222, 1024, 128), (222, 1024, 2048), ] @@ -92,10 +90,7 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize("num_experts", [8, 16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", - [32, 64, 128, 192, 224, 256, 512]) -@pytest.mark.parametrize("K", [128, 256, 1024]) -@pytest.mark.parametrize("N", [128, 256, 512, 1024]) +@pytest.mark.parametrize(("max_tokens_per_expert", "N", "K"), MNK_FACTORS) @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) @@ -207,9 +202,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) -@pytest.mark.parametrize("m", [1, 32, 45, 64, 222]) -@pytest.mark.parametrize("n", [128, 512, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024, 2048]) +@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index d7df5bf77035..6446a8d9503e 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -154,12 +154,13 @@ def make_modular_kernel( deepep_ht_args = ht_args, deepep_ll_args = ll_args) + num_dispatchers = pgi.world_size // dp_size + if low_latency_mode: assert not per_act_token_quant, "not supported in ll mode" fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, - world_size=pgi.world_size, - dp_size=dp_size, + num_dispatchers=num_dispatchers, use_fp8_w8a8=is_quantized, use_int8_w8a8=False, use_int8_w8a16=False, From d2dd40567170549f41bfed88df34cba274d6d913 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Jul 2025 20:52:55 +0000 Subject: [PATCH 67/77] fix merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 3582ff6dc4ea..ce523d0994fb 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -90,7 +90,10 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize("num_experts", [8, 16, 32]) -@pytest.mark.parametrize(("max_tokens_per_expert", "N", "K"), MNK_FACTORS) +@pytest.mark.parametrize("max_tokens_per_expert", + [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("K", [128, 256, 1024]) +@pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) From ae91a5e9302c07052c926d9ae3c116b66200495f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 02:23:51 +0000 Subject: [PATCH 68/77] trim testcases Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index ce523d0994fb..c9a4375ac939 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -93,7 +93,7 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) -@pytest.mark.parametrize("N", [128, 256, 512, 1024]) +@pytest.mark.parametrize("N", [128, 256, 1024]) @pytest.mark.parametrize( "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) From 76c697a41a85c2622aa56915e6ea4d734da6558c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 02:37:07 +0000 Subject: [PATCH 69/77] fix lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 74b3fcbdd3b0..f4fdfcb1fd3b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -137,6 +137,7 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = DeepEPHTPrepareAndFinalize( handle, + rank=all2all_manager.rank, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * moe.num_local_experts, @@ -240,18 +241,12 @@ def select_gemm_impl( assert self.fused_experts == fused_experts - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) - assert self.moe.dp_size == all2all_manager.dp_world_size return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, + num_dispatchers=self.moe.num_dispatchers, ) else: logger.debug("TritonExperts %s", self.moe) From a5c8e85fa6614740bb701690c2101bb8abfdaac0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 16:04:10 +0000 Subject: [PATCH 70/77] ping Signed-off-by: Bill Nell From 285b2bc8186baa94af179840f5492db55e7c7c33 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 18:15:24 +0000 Subject: [PATCH 71/77] fix num_dispatchers for TP+DP Signed-off-by: Bill Nell --- tests/kernels/moe/parallel_utils.py | 2 ++ .../model_executor/layers/fused_moe/config.py | 10 +------- .../fused_moe/deepep_ht_prepare_finalize.py | 8 +++++-- .../fused_moe/deepep_ll_prepare_finalize.py | 5 ++++ .../layers/fused_moe/fused_batched_moe.py | 7 +++++- vllm/model_executor/layers/fused_moe/layer.py | 23 +++++++------------ .../layers/fused_moe/modular_kernel.py | 4 ++++ .../layers/fused_moe/pplx_prepare_finalize.py | 7 ++++-- .../layers/fused_moe/prepare_finalize.py | 3 +++ .../compressed_tensors_moe.py | 4 ++-- .../model_executor/layers/quantization/fp8.py | 2 +- 11 files changed, 43 insertions(+), 32 deletions(-) diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index 497234a9cd6e..e06856d5dc61 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -137,6 +137,7 @@ def make_deepep_ht_a2a(pg: ProcessGroup, low_latency_mode=low_latency_mode, num_qps_per_rank=num_qps_per_rank) return DeepEPHTPrepareAndFinalize(buffer=buffer, + num_dispatchers=pgi.world_size, rank=pgi.rank, dp_size=dp_size, rank_expert_offset=pgi.rank * @@ -164,6 +165,7 @@ def make_deepep_ll_a2a(pg: ProcessGroup, return DeepEPLLPrepareAndFinalize( buffer=buffer, + num_dispatchers=pgi.world_size, max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index e3b8e1cf2c3c..2aaf275fb11d 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -161,7 +161,6 @@ class FusedMoEParallelConfig: tp_rank: int dp_rank: int ep_rank: int - num_dispatchers: int use_ep: bool # whether to use EP or not @@ -185,7 +184,7 @@ def use_deepep_ll_kernels(self): and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") @staticmethod - def make(tp_size_: int, dp_size_: int, num_dispatchers_: int, + def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. Based on the input tp_size_, @@ -196,7 +195,6 @@ def make(tp_size_: int, dp_size_: int, num_dispatchers_: int, tp_size_ (int): tp_size passed into the FusedMoE constructor. dp_size_ (int): dp_size passed into the FusedMoE constructor. ep_size_ (int): ep_size passed into the FusedMoE constructor. - num_dispatchers_ (int): the number of DP dispatchers. vllm_parallel_config (ParallelConfig): vllm's parallel config object. @@ -275,7 +273,6 @@ def flatten_tp_across_dp(dp_rank: int): dp_rank=dp_rank, ep_size=1, ep_rank=0, - num_dispatchers=num_dispatchers_, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -289,7 +286,6 @@ def flatten_tp_across_dp(dp_rank: int): dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, - num_dispatchers=num_dispatchers_, use_ep=True) @@ -357,10 +353,6 @@ def dp_size(self): def ep_size(self): return self.moe_parallel_config.ep_size - @property - def num_dispatchers(self): - return self.moe_parallel_config.num_dispatchers - @property def tp_rank(self): return self.moe_parallel_config.tp_rank diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index bba89d3406d6..8b7ef5e20927 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -16,10 +16,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. """ - def __init__(self, buffer: deep_ep.Buffer, rank: int, dp_size: int, - rank_expert_offset: int): + def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, rank: int, + dp_size: int, rank_expert_offset: int): super().__init__() self.buffer = buffer + self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset # The dispatch function returns a handle that the combine function @@ -30,6 +31,9 @@ def __init__(self, buffer: deep_ep.Buffer, rank: int, dp_size: int, # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index dec924d9d65b..9ce94e045d42 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -42,6 +42,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__(self, buffer: deep_ep.Buffer, max_tokens_per_rank: int, + num_dispatchers: int, use_fp8_dispatch: bool = False): super().__init__() @@ -52,6 +53,10 @@ def __init__(self, # requires. We store the handle here so it is available to the # combine function. self.handle = None + self.num_dispatchers_ = num_dispatchers + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ @property def activation_format(self) -> mk.FusedMoEActivationFormat: diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 38b4336d5c96..98b3f89b5930 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -472,12 +472,14 @@ def __init__( self, max_num_tokens: int, num_local_experts: int, + num_dispatchers: int, rank: int, ): super().__init__() self.max_num_tokens = max_num_tokens self.num_local_experts = num_local_experts self.rank = rank + self.num_dispatchers_ = num_dispatchers @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -489,6 +491,9 @@ def max_num_tokens_per_rank(self) -> Optional[int]: def topk_indices_dtype(self) -> Optional[torch.dtype]: return None + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + def prepare( self, a1: torch.Tensor, @@ -670,7 +675,7 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - num_dp = self.num_dispatchers # global // local? + num_dp = self.num_dispatchers num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f4fdfcb1fd3b..80cb9975ca8d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -14,7 +14,6 @@ from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, get_tensor_model_parallel_world_size, - get_world_group, tensor_model_parallel_all_reduce) from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context @@ -114,8 +113,8 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, hidden_dim_scale_bytes=hidden_scale_bytes, ) - assert (all2all_manager.world_size // - all2all_manager.tp_group.world_size) == moe.num_dispatchers + num_dispatchers = (all2all_manager.world_size // + all2all_manager.tp_group.world_size) # Intranode pplx a2a takes a group name while internode does not. if not all2all_manager.internode: @@ -128,7 +127,7 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, handle, max_num_tokens=moe.max_num_tokens, num_local_experts=moe.num_local_experts, - num_dispatchers=moe.num_dispatchers, + num_dispatchers=num_dispatchers, ) elif moe.use_deepep_ht_kernels: assert moe.dp_size == all2all_manager.dp_world_size @@ -137,6 +136,7 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = DeepEPHTPrepareAndFinalize( handle, + num_dispatchers=all2all_manager.world_size, rank=all2all_manager.rank, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * @@ -166,6 +166,7 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize = DeepEPLLPrepareAndFinalize( handle, max_tokens_per_rank=moe.max_num_tokens, + num_dispatchers=all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, ) @@ -246,7 +247,7 @@ def select_gemm_impl( logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, - num_dispatchers=self.moe.num_dispatchers, + num_dispatchers=prepare_finalize.num_dispatchers(), ) else: logger.debug("TritonExperts %s", self.moe) @@ -642,16 +643,12 @@ def __init__( get_tensor_model_parallel_world_size()) dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size) - world_size_ = get_world_group().world_size - - num_dispatchers = world_size_ // tp_size_ vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( tp_size_=tp_size_, dp_size_=dp_size_, - num_dispatchers_=num_dispatchers, vllm_parallel_config=vllm_config.parallel_config)) self.global_num_experts = num_experts + num_redundant_experts @@ -1324,12 +1321,8 @@ def maybe_all_reduce_tensor_model_parallel( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - # TBD - if hidden_states.shape[0] < envs.VLLM_FUSED_MOE_CHUNK_SIZE: - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2ffb4d328eca..f332b5168913 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -193,6 +193,10 @@ def max_num_tokens_per_rank(self) -> Optional[int]: """ raise NotImplementedError + @abstractmethod + def num_dispatchers(self) -> int: + raise NotImplementedError + class FusedMoEPermuteExpertsUnpermute(ABC): """ diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index aae3f6e1c8d8..33bab9cd0079 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -68,7 +68,7 @@ def __init__( self.a2a = a2a self.max_num_tokens = max_num_tokens self.num_local_experts = num_local_experts - self.num_dispatchers = num_dispatchers + self.num_dispatchers_ = num_dispatchers @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -80,6 +80,9 @@ def max_num_tokens_per_rank(self) -> Optional[int]: def topk_indices_dtype(self) -> Optional[torch.dtype]: return torch.uint32 + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + def prepare( self, a1: torch.Tensor, @@ -145,7 +148,7 @@ def prepare( expert_x = torch.empty( (self.num_local_experts, - self.max_num_tokens * self.num_dispatchers, hidden_dim), + self.max_num_tokens * self.num_dispatchers(), hidden_dim), dtype=a1q.dtype, device=device, ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 9e4be82f6c1f..e1114efe5a3f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -24,6 +24,9 @@ def max_num_tokens_per_rank(self) -> Optional[int]: def topk_indices_dtype(self) -> Optional[torch.dtype]: return None + def num_dispatchers(self) -> int: + return 1 + def prepare( self, a1: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b396fb354cdf..8351326ce7c5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -592,7 +592,7 @@ def select_gemm_impl( return BatchedTritonExperts( max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=moe.num_dispatchers, + num_dispatchers=prepare_finalize.num_dispatchers(), use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, per_act_token_quant=( @@ -866,7 +866,7 @@ def select_gemm_impl( experts = CutlassExpertsFp8( num_experts, - moe.num_dispatchers, + prepare_finalize.num_dispatchers(), moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7cbf65db9b76..26be4f8516dd 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -800,7 +800,7 @@ def select_gemm_impl( self.quant_config.weight_block_size, False) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=moe.num_dispatchers, + num_dispatchers=prepare_finalize.num_dispatchers(), use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, per_act_token_quant=False, From 286d988c7595c55dd2a9ae8e35d0c8d574fdc43b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 18:24:39 +0000 Subject: [PATCH 72/77] fix unit test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- tests/kernels/moe/utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index a105c901d37f..da23c83bb02a 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -527,7 +527,7 @@ def pplx_moe( experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers, + num_dispatchers=prepare_finalize.num_dispatchers(), use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, block_shape=block_shape, per_act_token_quant=per_act_token_quant, diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index b7b16b7f2e35..df89ad7e6da6 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -63,6 +63,7 @@ def batched_moe( fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, + num_dispatchers=1, num_local_experts=w1.shape[0], rank=0), BatchedTritonExperts( @@ -103,6 +104,7 @@ def naive_batched_moe( fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, + num_dispatchers=1, num_local_experts=w1.shape[0], rank=0), NaiveBatchedExperts( From 14542e586063fbe0d3e920cd3f3d9d9b80c2e052 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 21:05:36 +0000 Subject: [PATCH 73/77] review comments Signed-off-by: Bill Nell --- tests/kernels/moe/parallel_utils.py | 1 - tests/kernels/moe/test_pplx_moe.py | 3 --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 4 +++- vllm/model_executor/layers/fused_moe/config.py | 4 ++-- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 1 - .../layers/fused_moe/deepep_ht_prepare_finalize.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 1 - vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py | 2 +- 8 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index e06856d5dc61..98ae4c8cd34e 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -138,7 +138,6 @@ def make_deepep_ht_a2a(pg: ProcessGroup, num_qps_per_rank=num_qps_per_rank) return DeepEPHTPrepareAndFinalize(buffer=buffer, num_dispatchers=pgi.world_size, - rank=pgi.rank, dp_size=dp_size, rank_expert_offset=pgi.rank * ht_args.num_local_experts) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index da23c83bb02a..d28e0e040629 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -304,7 +304,6 @@ def pplx_prepare_finalize( device = pgi.device rank = pgi.rank world_size = pgi.world_size - print(f"PGI {pgi} {world_size} {dp_size}") topk_ids = topk_ids.to(dtype=torch.uint32) @@ -869,7 +868,6 @@ def format_result(msg, ex=None): f"rank={pgi.rank}.") -#@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx @@ -883,7 +881,6 @@ def test_pplx_prepare_finalize( use_internode, False, _pplx_prepare_finalize) -#@pytest.mark.parametrize("world_dp_size", [[2, 1], [2, 2], [4, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 676629b38af3..1a1ccf0aaa85 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -189,7 +189,9 @@ def __init__(self, per_act_token_quant=False): """ max_num_tokens: Maximum number of tokens from a DP Rank - block_shape: Block quantization block shape + num_dispatchers: The number of DP dispatchers. + block_shape: Block quantization block shape. + per_act_token_quant: Per activation token quantization flag. """ super().__init__( FusedMoEQuantConfig( diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 2aaf275fb11d..6c03732030d1 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -82,7 +82,7 @@ def is_per_act_token(self) -> bool: return self.per_act_token_quant @property - def is_grouped(self) -> bool: + def is_block_quantized(self) -> bool: return self.block_shape is not None @property @@ -95,7 +95,7 @@ def scale_shape( hidden_dim: int, ) -> Optional[tuple[int, int]]: if self.is_quantized: - if self.is_grouped: + if self.is_block_quantized: assert self.block_shape is not None _, block_k = self.block_shape k_tiles = cdiv(hidden_dim, block_k) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 41063706c361..56abe700daf9 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -178,7 +178,6 @@ def run_cutlass_moe_fp8( c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) - # Should this be filled always? c1.fill_(0) ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 8b7ef5e20927..37998334327f 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -16,7 +16,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. """ - def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, rank: int, + def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, dp_size: int, rank_expert_offset: int): super().__init__() self.buffer = buffer diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 80cb9975ca8d..648dfca374c5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -137,7 +137,6 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize = DeepEPHTPrepareAndFinalize( handle, num_dispatchers=all2all_manager.world_size, - rank=all2all_manager.rank, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * moe.num_local_experts, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 33bab9cd0079..6eb1c0152747 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -133,7 +133,7 @@ def prepare( orig_a_scale_block_shape = a1q_scale.shape[-1] - if not quant_config.is_grouped: + if not quant_config.is_block_quantized: # TODO (bnell): use group_broadcast instead? a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) From 2a9628945f4e0830961c5f1e91760fa59bb04986 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 21:37:05 +0000 Subject: [PATCH 74/77] remove debug cruft Signed-off-by: Bill Nell --- vllm/model_executor/layers/quantization/fp8.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 26be4f8516dd..53734a2393f0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -902,8 +902,6 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) else: - #print(f"A1_SCALE = {layer.w13_input_scale}") - #print(f"A2_SCALE = {layer.w2_input_scale}") return self.fused_experts( hidden_states=x, w1=layer.w13_weight, From 562bb3ef0c2714f3da72179c05c46c706ed2723d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Jul 2025 23:24:01 +0000 Subject: [PATCH 75/77] review comments + scout fix Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_batched_moe.py | 70 +++++++++---------- .../layers/fused_moe/pplx_prepare_finalize.py | 2 - .../compressed_tensors_moe.py | 2 +- vllm/model_executor/models/granitemoe.py | 7 -- 4 files changed, 36 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 98b3f89b5930..0355abbf1d2b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -30,14 +30,14 @@ def moe_mmk( # moving by 1 element in a particular dimension. E.g. `stride_am` is # how much to increase `a_ptr` by to get the element one row down # (A has M rows). - stride_ak, - stride_bk, - stride_ase, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, + stride_ak: tl.int64, + stride_bk: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, # Offsets and masks offs_m, offs_n, @@ -150,18 +150,18 @@ def expert_triton_kernel( b_scale_ptr, b_zp_ptr, # strides - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_ase, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, + stride_am: tl.int64, + stride_ak: tl.int64, + stride_bk: tl.int64, + stride_bn: tl.int64, + stride_cm: tl.int64, + stride_cn: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, # offsets offs_bn, # Blockwise quantization data @@ -248,21 +248,21 @@ def batched_triton_kernel( # moving by 1 element in a particular dimension. E.g. `stride_am` is # how much to increase `a_ptr` by to get the element one row down # (A has M rows). - stride_ae, - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_ce, - stride_cm, - stride_cn, - stride_ase, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, + stride_ae: tl.int64, + stride_am: tl.int64, + stride_ak: tl.int64, + stride_be: tl.int64, + stride_bk: tl.int64, + stride_bn: tl.int64, + stride_ce: tl.int64, + stride_cm: tl.int64, + stride_cn: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, # Blockwise quantization data group_n: tl.constexpr, group_k: tl.constexpr, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 6eb1c0152747..112305a4f2d0 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -174,8 +174,6 @@ def prepare( round_up(final_dim, 4) # round up for alignment ) - # TODO (bnell): make sure shape matches up with pplx hidden bytes - expert_x_scale = torch.empty( expert_x_scale_shape, dtype=torch.float32, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8351326ce7c5..2162b37e4bc2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -866,10 +866,10 @@ def select_gemm_impl( experts = CutlassExpertsFp8( num_experts, - prepare_finalize.num_dispatchers(), moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + num_dispatchers=prepare_finalize.num_dispatchers(), use_batched_format=use_batched_format, ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 61667749a536..5a70f3a616c6 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -92,8 +92,6 @@ def __init__(self, tp_size=tp_size, prefix=f"{prefix}.experts") - self.tp_size = tp_size if tp_size is not None else 1 - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape @@ -101,11 +99,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) - - if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) - return final_hidden_states.view(orig_shape) From a9b0730d1631e58d47bc90cfd73ddd3663e6fadc Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Jul 2025 00:09:37 +0000 Subject: [PATCH 76/77] remove bogus assert Signed-off-by: Bill Nell --- .../layers/fused_moe/deepep_ll_prepare_finalize.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 9ce94e045d42..44d0a2b18b1d 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -92,8 +92,6 @@ def _do_quant( assert isinstance(x, torch.Tensor) - assert not per_act_token_quant - num_experts, max_tokens, hidden_dim = x.size() # TODO (varun): Optimization - Use a batched version of quant From b37026d1350c626802aaee50740c4c60a4882894 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Jul 2025 02:14:35 +0000 Subject: [PATCH 77/77] scout fixes Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 5 +-- .../compressed_tensors_moe.py | 42 ++++++++++++------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 56abe700daf9..d889f740a0c4 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -41,10 +41,7 @@ def run_cutlass_moe_fp8( assert w2_scale is not None assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn - if expert_num_tokens is None: - assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1" - else: - assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1" + assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1" assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" assert w1_scale.dim() == 1 or w1_scale.size( 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 2162b37e4bc2..bbbec8d3c78a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -577,6 +577,7 @@ def select_gemm_impl( prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: + from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) @@ -584,20 +585,27 @@ def select_gemm_impl( logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - assert (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts) - - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( + ) + assert max_num_tokens_per_rank is not None - return BatchedTritonExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), - ) + return BatchedTritonExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN), + ) + else: + return TritonExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN), + ) def apply( self, @@ -859,6 +867,8 @@ def select_gemm_impl( use_batched_format = (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts) + num_dispatchers = prepare_finalize.num_dispatchers() + num_experts = (moe.num_local_experts if use_batched_format else moe.num_experts) @@ -869,11 +879,13 @@ def select_gemm_impl( moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, - num_dispatchers=prepare_finalize.num_dispatchers(), + num_dispatchers=num_dispatchers, use_batched_format=use_batched_format, ) - self.disable_expert_map = not experts.supports_expert_map() + self.disable_expert_map = (num_dispatchers > 1 + or not experts.supports_expert_map()) + return experts def apply(