
Commit b8f883e

minosfuture authored and Chen-zexi committed
[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE (vllm-project#20166)
Signed-off-by: Ming Yang <yming@meta.com>
1 parent 384481f commit b8f883e

File tree

3 files changed (+17 / -11 lines)


csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@
 
 constexpr uint64_t THREADS_PER_EXPERT = 512;
 
-__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
+__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
                                       int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
   }
 }
 
-__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
+__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
                                   const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
 
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
   compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const uint32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(topk_ids.data_ptr()),
       static_cast<int32_t*>(problem_sizes1.data_ptr()),
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const uint32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<const int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
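Why the signedness matters: with an expert_map in play, topk_ids carries a -1 sentinel for non-local tokens, and reading those 32 bits as unsigned silently turns the sentinel into a huge, out-of-range expert id. A minimal illustrative sketch of that reinterpretation, using NumPy for the bit-level view (the variable names here are ours, not from the kernel):

import numpy as np

# topk_ids with the -1 sentinel that an expert_map assigns to non-local tokens.
topk_ids = np.array([0, 3, -1], dtype=np.int32)

# The same bits read as unsigned: -1 becomes 4294967295, which downstream
# code would treat as an (absurdly out-of-range) expert id.
print(topk_ids.view(np.uint32))  # [0 3 4294967295]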

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 4 additions & 2 deletions
@@ -78,7 +78,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
 
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
-        return torch.uint32
+        return torch.int32
 
     def num_dispatchers(self) -> int:
         return self.num_dispatchers_
@@ -100,7 +100,9 @@ def prepare(
         hidden_dim = a1.size(-1)  # K
 
         assert topk_ids.size(0) == num_tokens
-        # assert expert_map is None, "NYI"
+        assert expert_map is None, """with expert map, -1 id is used for
+            non-local token; this causes error when casting ids to the
+            topk_indices_dtype() uint32"""
 
         # Is this always going to be a1.device?
         device = a1.device
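A short sketch of the failure mode the new assertion guards against, in plain PyTorch (illustrative only): int32 can represent the -1 sentinel, while uint32 cannot.

import torch

# -1 marks a non-local token when an expert_map is used.
ids = torch.tensor([2, -1], dtype=torch.int64)

# Casting to the new indices dtype keeps the sentinel intact:
print(ids.to(torch.int32))  # tensor([ 2, -1], dtype=torch.int32)

# uint32 has no representation for -1; the cast would corrupt the sentinel
# (wrapping it to 4294967295), which is why prepare() now rejects expert_map.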

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 9 additions & 5 deletions
@@ -929,23 +929,27 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-        )
+            e_score_correction_bias=e_score_correction_bias)
+
+        a1_scale = layer.w13_input_scale
+        a2_scale = layer.w2_input_scale
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
 
         return self.fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
             topk_weights,
             topk_ids,
+            per_act_token=per_act_token,
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=None if self.disable_expert_map else expert_map,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
         )
 
 
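The per_act_token flag added above is inferred from the shape of the input scales: a single-element scale means per-tensor quantization, anything larger means per-activation-token. A standalone sketch of that inference (the function name is ours, for illustration; the logic is lifted from the diff):

import torch

def infer_per_act_token(a1_scale, a2_scale):
    # A scale with more than one element implies per-token activation
    # quantization; fall back from a1_scale to a2_scale, defaulting to False.
    return a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)

print(infer_per_act_token(torch.ones(1), None))       # False (per-tensor)
print(infer_per_act_token(torch.ones(128, 1), None))  # True  (per-token)
print(infer_per_act_token(None, None))                # False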
