Skip to content

Commit c8055da

Browse files
committed
feat: support EPLB for CompressedTensorsW8A8Fp8MoEMethod and CompressedTensorsW8A8Fp8MoECutlassMethod
1 parent e657c3e commit c8055da

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,10 @@ def apply(
633633
logical_replica_count: Optional[torch.Tensor] = None,
634634
) -> torch.Tensor:
635635
if enable_eplb:
636-
raise NotImplementedError(
637-
"EPLB not supported for "
638-
"`CompressedTensorsW8A8Fp8MoEMethod` yet.")
636+
assert expert_load_view is not None
637+
assert logical_to_physical_map is not None
638+
assert logical_replica_count is not None
639+
assert isinstance(layer, FusedMoE)
639640

640641
topk_weights, topk_ids = FusedMoE.select_experts(
641642
hidden_states=x,
@@ -649,6 +650,11 @@ def apply(
649650
scoring_func=scoring_func,
650651
e_score_correction_bias=e_score_correction_bias,
651652
indices_type=self.topk_indices_dtype,
653+
enable_eplb=enable_eplb,
654+
expert_map=expert_map,
655+
expert_load_view=expert_load_view,
656+
logical_to_physical_map=logical_to_physical_map,
657+
logical_replica_count=logical_replica_count,
652658
)
653659

654660
if self.rocm_aiter_moe_enabled:
@@ -913,9 +919,10 @@ def apply(
913919
logical_replica_count: Optional[torch.Tensor] = None,
914920
) -> torch.Tensor:
915921
if enable_eplb:
916-
raise NotImplementedError(
917-
"EPLB not supported for "
918-
"`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
922+
assert expert_load_view is not None
923+
assert logical_to_physical_map is not None
924+
assert logical_replica_count is not None
925+
assert isinstance(layer, FusedMoE)
919926

920927
topk_weights, topk_ids = FusedMoE.select_experts(
921928
hidden_states=x,
@@ -927,7 +934,12 @@ def apply(
927934
num_expert_group=num_expert_group,
928935
custom_routing_function=custom_routing_function,
929936
scoring_func=scoring_func,
930-
e_score_correction_bias=e_score_correction_bias)
937+
e_score_correction_bias=e_score_correction_bias,
938+
enable_eplb=enable_eplb,
939+
expert_map=expert_map,
940+
expert_load_view=expert_load_view,
941+
logical_to_physical_map=logical_to_physical_map,
942+
logical_replica_count=logical_replica_count)
931943

932944
a1_scale = layer.w13_input_scale
933945
a2_scale = layer.w2_input_scale

0 commit comments

Comments
 (0)