From 6d0e19a6e6d885a0820debbe8d91c3c40412df0e Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Tue, 24 Jun 2025 14:59:58 -0700
Subject: [PATCH 1/6] [Bugfix] Fix topk_ids indices_type for cutlass w8a8 fp8 moe

Signed-off-by: Ming Yang
---
 .../quantization/compressed_tensors/compressed_tensors_moe.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 5d7e00c2b81..4b54cd6878f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -883,9 +883,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-        )
+            e_score_correction_bias=e_score_correction_bias)
 
         return self.fused_experts(
             x,

From 7c57bb0601026b99a39fdb9df220ed01a4924391 Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Fri, 27 Jun 2025 16:06:47 -0700
Subject: [PATCH 2/6] Address comment: change moe_data topk_ids type back to int32_t

Signed-off-by: Ming Yang
---
 csrc/quantization/cutlass_w8a8/moe/moe_data.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu
index 32254641cc3..80c6589ab17 100644
--- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu
@@ -7,7 +7,7 @@
 
 constexpr uint64_t THREADS_PER_EXPERT = 512;
 
-__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
+__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
                                       int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
   }
 }
 
-__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
+__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
                                   const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
 
   compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const uint32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(topk_ids.data_ptr()),
       static_cast<int32_t*>(problem_sizes1.data_ptr()),
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const uint32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(topk_ids.data_ptr()),
       static_cast<int32_t*>(expert_offsets.data_ptr()),
       static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),

From 376bcde5a953bddb9ac7f6ceb9f79d0f4c289cbc Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Thu, 3 Jul 2025 11:06:38 -0700
Subject: [PATCH 3/6] Address comment: add expert_map==None assertion in pplx_prepare_finalize

Signed-off-by: Ming Yang
---
 .../layers/fused_moe/deepep_ll_prepare_finalize.py            | 2 +-
 vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index b315b4a97f0..a682b13e10e 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -65,7 +65,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_tokens_per_rank
 
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
-        return torch.int64
+        return torch.int32
 
     def _do_quant(
         self,
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 45e813287d3..f601458dc81 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -100,7 +100,9 @@ def prepare(
         hidden_dim = a1.size(-1)  # K
 
         assert topk_ids.size(0) == num_tokens
-        # assert expert_map is None, "NYI"
+        assert expert_map is None, """with expert map, -1 id is used for
+            non-local token; this causes error when casting ids to the
+            topk_indices_dtype() uint32"""
 
         # Is this always going to be a1.device?
         device = a1.device

From 6ca83a48d3a4830a93fce501d925761d15fb700d Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Thu, 3 Jul 2025 13:54:56 -0700
Subject: [PATCH 4/6] Fix after rebase: cutlass_moe_fp8 signature is changed

Signed-off-by: Ming Yang
---
 .../compressed_tensors/compressed_tensors_moe.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 4b54cd6878f..e754c5a16da 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -885,19 +885,25 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
+        a1_scale = layer.w13_input_scale
+        a2_scale = layer.w2_input_scale
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
+
         return self.fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
             topk_weights,
             topk_ids,
+            per_act_token=per_act_token,
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=None if self.disable_expert_map else expert_map,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
         )
 
 

From ec668c3190d74c8a0b5415ea85784a1cdb0703fd Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Thu, 3 Jul 2025 14:52:30 -0700
Subject: [PATCH 5/6] Address comment: change pplx id type to int32

Signed-off-by: Ming Yang
---
 vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index f601458dc81..10662410e36 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -81,7 +81,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
 
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
-        return torch.uint32
+        return torch.int32
 
     def prepare(
         self,

From beeba9a19cb9796933b2a9dfb4435dbca1110fd2 Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Mon, 7 Jul 2025 13:37:54 -0700
Subject: [PATCH 6/6] Address comment: revert type change for deepep_ll for now

Signed-off-by: Ming Yang
---
 .../layers/fused_moe/deepep_ll_prepare_finalize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index a682b13e10e..b315b4a97f0 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -65,7 +65,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_tokens_per_rank
 
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
-        return torch.int32
+        return torch.int64
 
     def _do_quant(
        self,
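
Illustrative note (not part of the patch series): the series settles on a signed 32-bit dtype for topk_ids because expert_map marks non-local tokens with a -1 sentinel, and casting that sentinel to an unsigned type silently wraps it into a huge expert index. A minimal sketch of the effect follows; numpy is used here purely for illustration, while the actual code paths operate on torch tensors and int32_t kernel pointers.

    import numpy as np

    # Hypothetical routing output: expert_map assigns -1 to non-local tokens.
    topk_ids = np.array([3, 7, -1, 2], dtype=np.int64)

    # Casting to an unsigned 32-bit type silently wraps the -1 sentinel...
    print(topk_ids.astype(np.uint32))  # [3 7 4294967295 2]

    # ...while a signed 32-bit type preserves it, matching the int32_t
    # kernel arguments and the torch.int32 returned by topk_indices_dtype().
    print(topk_ids.astype(np.int32))   # [3 7 -1 2]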