
Commit 3b3b778

[Bugfix] Fix a couple PPLX+CUTLASS MoE bugs (#20825)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
1 parent 42d440c

File tree

2 files changed (+37 −20 lines changed):
  vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
  vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 2 additions & 2 deletions
@@ -204,7 +204,7 @@ def prepare(
             out_expert_x_scale=expert_x_scale,
             dp_x=a1q,
             dp_x_scale=a1q_scale,
-            indices=topk_ids,
+            indices=topk_ids.view(dtype=torch.uint32),
             bound_m=bound_m,
         )

@@ -249,7 +249,7 @@ def finalize(
             topk_weights = torch.ones_like(topk_weights)

         self.a2a.combine(out_tokens=output,
-                         indices=topk_ids,
+                         indices=topk_ids.view(dtype=torch.uint32),
                          weights=topk_weights,
                          expert_y=fused_expert_output,
                          bound_m=bound_m)
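
For context on this fix: the pplx dispatch/combine kernels take their indices argument as uint32, while vLLM's topk_ids is a signed integer tensor, so the commit reinterprets it in place with Tensor.view(dtype=...). A minimal sketch of that reinterpretation, assuming a PyTorch build that supports torch.uint32; the toy values below are illustrative, not from the commit:

import torch

# Sketch (not from the commit): topk_ids as produced by top-k routing,
# a signed integer tensor.
topk_ids = torch.tensor([[0, 3], [1, 2]], dtype=torch.int32)

# view(dtype=...) reinterprets the same storage as uint32 without a copy,
# matching what the pplx all-to-all expects for its `indices` argument.
indices = topk_ids.view(dtype=torch.uint32)

assert indices.dtype == torch.uint32
assert indices.data_ptr() == topk_ids.data_ptr()  # same memory, no copy

Because this is a pure reinterpretation of the existing buffer, the bit pattern is unchanged and no extra kernel launch or allocation is needed.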

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 35 additions & 18 deletions
@@ -737,10 +737,8 @@ def __init__(
                 "For FP8 Fused MoE layer, we require either per tensor or "
                 "channelwise, dynamic per token quantization.")

-        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-            cutlass_moe_fp8)
         self.topk_indices_dtype = None
-        self.fused_experts = cutlass_moe_fp8  # type: ignore
+        self.fused_experts = None  # type: ignore
         self.disable_expert_map = False

     def create_weights(self, layer: torch.nn.Module, num_experts: int,

@@ -936,21 +934,40 @@ def apply(
         per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
             a2_scale.numel() != 1 if a2_scale is not None else False)

-        return self.fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights,
-            topk_ids,
-            per_act_token=per_act_token,
-            activation=activation,
-            global_num_experts=global_num_experts,
-            expert_map=None if self.disable_expert_map else expert_map,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-        )
+        if self.fused_experts is None:
+            # If no modular kernel is provided, use cutlass_moe_fp8
+            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+                cutlass_moe_fp8)
+            return cutlass_moe_fp8(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                per_act_token=per_act_token,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=None if self.disable_expert_map else expert_map,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale,
+            )
+        else:
+            return self.fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=None if self.disable_expert_map else expert_map,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )


 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
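
For context on this fix: __init__ no longer binds cutlass_moe_fp8 eagerly, and apply() falls back to it only when no modular kernel has been attached to self.fused_experts. The modular path also drops per_act_token and sources its activation scales from the layer (w13_input_scale / w2_input_scale) rather than the call-site a1_scale / a2_scale. A minimal sketch of that dispatch shape, using hypothetical stand-ins (FakeMoEMethod, default_moe) rather than the real vLLM classes:

from typing import Callable, Optional

import torch


def default_moe(x: torch.Tensor, a1_scale, a2_scale) -> torch.Tensor:
    # Hypothetical stand-in for the cutlass_moe_fp8 fallback path.
    return x


class FakeMoEMethod:
    def __init__(self) -> None:
        # No kernel is bound at construction time; a modular kernel
        # (e.g. a PPLX prepare/finalize pipeline) may be attached later.
        self.fused_experts: Optional[Callable[..., torch.Tensor]] = None

    def apply(self, layer, x, a1_scale, a2_scale) -> torch.Tensor:
        if self.fused_experts is None:
            # Default path: activation scales come from the call site.
            return default_moe(x, a1_scale, a2_scale)
        # Modular-kernel path: activation scales come from the layer's
        # static input scales instead of the call-site arguments.
        return self.fused_experts(x,
                                  a1_scale=layer.w13_input_scale,
                                  a2_scale=layer.w2_input_scale)

The real apply() threads the full weight, top-k, and expert-map argument list shown in the diff through both branches; only the entry point and the scale sources differ.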
