From 15671806fcd28d22c90673c2b846cd1e12df9776 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Tue, 24 Jun 2025 14:59:58 -0700 Subject: [PATCH 1/2] [Bugfix] Fix topk_ids indices_type for cutlass w8a8 fp8 moe Signed-off-by: Ming Yang --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7703b9e687c4..e624c58c02c1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -621,8 +621,7 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=torch.uint32) + e_score_correction_bias=e_score_correction_bias) return self.fused_experts( x, From e5ed82218abb34cedb3c951a624430e64ddc0c76 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Fri, 27 Jun 2025 16:06:47 -0700 Subject: [PATCH 2/2] Address comment: change moe_data topk_ids type back to int32_t Signed-off-by: Ming Yang --- csrc/quantization/cutlass_w8a8/moe/moe_data.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 32254641cc38..80c6589ab171 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -7,7 +7,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512; -__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids, +__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, @@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets( } } -__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids, +__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, @@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller( int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), + static_cast(topk_ids.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); @@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller( static_cast(atomic_buffer.data_ptr()), num_experts); } compute_arg_sorts<<>>( - static_cast(topk_ids.data_ptr()), + static_cast(topk_ids.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()),