Skip to content

Commit fcb9f87

Browse files
authored
[Bugfix] Correct per_act_token in CompressedTensorsW8A8Fp8MoECutlassM… (#20937)
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent 3ed94f9 commit fcb9f87

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -929,10 +929,8 @@ def apply(
929929
scoring_func=scoring_func,
930930
e_score_correction_bias=e_score_correction_bias)
931931

932-
a1_scale = layer.w13_input_scale
933-
a2_scale = layer.w2_input_scale
934-
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
935-
a2_scale.numel() != 1 if a2_scale is not None else False)
932+
per_act_token = (
933+
self.input_quant.strategy == QuantizationStrategy.TOKEN)
936934

937935
if self.fused_experts is None:
938936
# If no modular kernel is provided, use cutlass_moe_fp8
@@ -950,8 +948,8 @@ def apply(
950948
expert_map=None if self.disable_expert_map else expert_map,
951949
w1_scale=layer.w13_weight_scale,
952950
w2_scale=layer.w2_weight_scale,
953-
a1_scale=a1_scale,
954-
a2_scale=a2_scale,
951+
a1_scale=layer.w13_input_scale,
952+
a2_scale=layer.w2_input_scale,
955953
)
956954
else:
957955
return self.fused_experts(

0 commit comments

Comments (0)