Commit 0e3fe89
Support Llama 4 for fused_marlin_moe (#20457)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 1caca5a · commit 0e3fe89

6 files changed: +11 −17 lines

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py (4 additions, 2 deletions)

@@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      topk_weights: torch.Tensor,
                      topk_ids: torch.Tensor,
                      quant_type_id: int,
+                     apply_router_weight_on_input: bool = False,
                      global_num_experts: int = -1,
                      expert_map: Optional[torch.Tensor] = None,
                      global_scale1: Optional[torch.Tensor] = None,
@@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         topk_weights,
         moe_block_size=block_size_m,
         top_k=topk,
-        mul_topk_weights=False,
+        mul_topk_weights=apply_router_weight_on_input,
         is_ep=expert_map is not None,
         b_q_type=quant_type,
         size_m=M,
@@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         topk_weights,
         moe_block_size=block_size_m,
         top_k=1,
-        mul_topk_weights=True,
+        mul_topk_weights=not apply_router_weight_on_input,
         is_ep=expert_map is not None,
         b_q_type=quant_type,
         size_m=M * topk,
@@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           quant_type_id: int,
+                          apply_router_weight_on_input: bool = False,
                           global_num_experts: int = -1,
                           global_scale1: Optional[torch.Tensor] = None,
                           global_scale2: Optional[torch.Tensor] = None,
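
The substance of the change is the pair of mul_topk_weights toggles above: the router weights are applied exactly once, on the first GEMM's input when apply_router_weight_on_input=True (as Llama 4 expects), and on the second GEMM's output otherwise. Because the expert MLP is nonlinear, the two placements are a genuine modeling difference, not a refactor. Below is a minimal, unfused PyTorch sketch of that contract; it is not vLLM's kernel, and the reference_moe name and per-expert callables are illustrative assumptions only.

import torch

def reference_moe(x: torch.Tensor,              # [num_tokens, hidden]
                  experts: list,                # hypothetical per-expert callables
                  topk_weights: torch.Tensor,   # [num_tokens, top_k], float
                  topk_ids: torch.Tensor,       # [num_tokens, top_k], long
                  apply_router_weight_on_input: bool = False) -> torch.Tensor:
    out = torch.zeros_like(x)
    num_tokens, top_k = topk_ids.shape
    for t in range(num_tokens):
        for k in range(top_k):
            w = topk_weights[t, k]
            expert = experts[int(topk_ids[t, k])]
            if apply_router_weight_on_input:
                # Llama 4 style: scale the token before the expert, so the
                # first GEMM consumes the weighted input and the output is
                # summed as-is (mirrors mul_topk_weights on the first GEMM).
                out[t] += expert(w * x[t])
            else:
                # Default path: run the expert on the raw token and weight
                # its output (mirrors mul_topk_weights on the second GEMM).
                out[t] += w * expert(x[t])
    return out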

vllm/model_executor/layers/quantization/awq_marlin.py (1 addition, 5 deletions)

@@ -493,11 +493,6 @@ def apply(

         assert activation == "silu", "Only SiLU activation is supported."

-        if apply_router_weight_on_input:
-            raise NotImplementedError(
-                "Apply router weight on input is not supported for"
-                "fused Marlin MoE method.")
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -520,6 +515,7 @@ def apply(
             topk_weights,
             topk_ids,
             quant_type_id=self.quant_type.id,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_zeros=layer.w13_qzeros,

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py (3 additions, 4 deletions)

@@ -322,6 +322,7 @@ def apply(
             global_scale1=layer.w13_weight_scale_2,
             global_scale2=layer.w2_weight_scale_2,
             quant_type_id=scalar_types.float4_e2m1f.id,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map)

@@ -669,8 +670,6 @@ def apply(
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
-            assert not apply_router_weight_on_input, (
-                "Apply router weight on input not supported for Marlin MoE.")
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
@@ -681,6 +680,7 @@ def apply(
                 topk_weights,
                 topk_ids,
                 quant_type_id=scalar_types.float8_e4m3fn.id,
+                apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)

@@ -1356,8 +1356,6 @@ def apply(

         assert activation == "silu", (
             f"{activation} not supported for Marlin MoE.")
-        assert not apply_router_weight_on_input, (
-            "Apply router weight on input not supported for Marlin MoE.")

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -1381,6 +1379,7 @@ def apply(
             topk_weights,
             topk_ids,
             quant_type_id=self.quant_type.id,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             g_idx1=layer.w13_weight_g_idx,

vllm/model_executor/layers/quantization/fp8.py (1 addition, 2 deletions)

@@ -889,8 +889,6 @@ def apply(
         elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
-            assert not apply_router_weight_on_input, (
-                "Apply router weight on input not supported for Marlin MoE.")
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
@@ -901,6 +899,7 @@ def apply(
                 topk_weights,
                 topk_ids,
                 quant_type_id=scalar_types.float8_e4m3fn.id,
+                apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
         else:

vllm/model_executor/layers/quantization/gptq_marlin.py (1 addition, 4 deletions)

@@ -645,10 +645,6 @@ def apply(
             "EPLB not supported for `GPTQMarlinMoEMethod` yet.")

         assert activation == "silu", "Only SiLU activation is supported."
-        if apply_router_weight_on_input:
-            raise NotImplementedError(
-                "Apply router weight on input is not supported for "
-                "fused Marlin MoE method.")

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -672,6 +668,7 @@ def apply(
             topk_weights,
             topk_ids,
             quant_type_id=self.quant_type.id,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             g_idx1=layer.w13_g_idx,

vllm/model_executor/layers/quantization/modelopt.py (1 addition, 0 deletions)

@@ -700,6 +700,7 @@ def apply(
             global_scale1=layer.w13_weight_scale_2,
             global_scale2=layer.w2_weight_scale_2,
             quant_type_id=scalar_types.float4_e2m1f.id,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map)
