From 4aa881457d1be93e86fd34a4e808ced5a3a2e585 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 11 Jul 2025 15:34:43 +0000 Subject: [PATCH 1/2] Fix a couple PPLX+CUTLASS bugs Signed-off-by: ElizaWszola --- .../layers/fused_moe/cutlass_moe.py | 1 + .../layers/fused_moe/pplx_prepare_finalize.py | 4 +- .../compressed_tensors_moe.py | 53 ++++++++++++------- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d6a30e34269..337ecb59572 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -320,6 +320,7 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) + topk_ids = topk_ids.view(dtype=torch.int32) in_dtype = hidden_states.dtype run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 46f1231a617..9fb63569d37 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -204,7 +204,7 @@ def prepare( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids, + indices=topk_ids.view(dtype=torch.uint32), bound_m=bound_m, ) @@ -249,7 +249,7 @@ def finalize( topk_weights = torch.ones_like(topk_weights) self.a2a.combine(out_tokens=output, - indices=topk_ids, + indices=topk_ids.view(dtype=torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c17a390dba5..baf4fec3cc6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -737,10 +737,8 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) self.topk_indices_dtype = None - self.fused_experts = cutlass_moe_fp8 # type: ignore + self.fused_experts = None # type: ignore self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -936,21 +934,40 @@ def apply( per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - per_act_token=per_act_token, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) + if self.fused_experts is None: + # If no modular kernel is provided, use cutlass_moe_fp8 + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) + return cutlass_moe_fp8( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + per_act_token=per_act_token, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + else: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): From 19e33fd590ed850c97a6c4398a10262cde593417 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 11 Jul 2025 15:43:31 +0000 Subject: [PATCH 2/2] Remove redundant typecast Signed-off-by: ElizaWszola --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 337ecb59572..d6a30e34269 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -320,7 +320,6 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) - topk_ids = topk_ids.view(dtype=torch.int32) in_dtype = hidden_states.dtype run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable,