
Commit acc742d

luccafong authored and patrickvonplaten committed
[Bugfix] Fix missing per_act_token parameter in compressed_tensors_moe (vllm-project#20509)
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 6de4394 commit acc742d

File tree

1 file changed: +4 -1 lines changed


vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 4 additions & 1 deletion
@@ -322,7 +322,7 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    per_act_token: bool,
+    per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
@@ -366,6 +366,9 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
+    if per_act_token is None:
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
     per_out_ch = w1_scale.numel() != w1_q.size(0)
 
     num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
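
The fallback added here mirrors how the quantization scales are shaped: per-tensor quantization carries a single scalar scale, while per-activation-token quantization carries one scale per token, so any scale tensor with more than one element implies per_act_token=True. A minimal standalone sketch of that inference (the helper name infer_per_act_token is hypothetical, not part of vLLM):

import torch
from typing import Optional

def infer_per_act_token(a1_scale: Optional[torch.Tensor] = None,
                        a2_scale: Optional[torch.Tensor] = None) -> bool:
    # A per-tensor scale is a single scalar; anything larger implies
    # one scale per activation token.
    if a1_scale is not None:
        return a1_scale.numel() != 1
    if a2_scale is not None:
        return a2_scale.numel() != 1
    return False  # no scales given: default to per-tensor quantization

assert infer_per_act_token(torch.ones(128, 1)) is True       # 128 per-token scales
assert infer_per_act_token(a2_scale=torch.ones(1)) is False  # single scalar scale
assert infer_per_act_token() is False                        # no scales at all

Making per_act_token optional this way keeps callers that never passed the argument working, while callers that know the quantization scheme can still pass it explicitly.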
