
Commit 31b96d1

Support Llama 4 for cutlass_moe_fp4 (#20453)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent e59ba9e commit 31b96d1

File tree: 3 files changed, +80 -74 lines

- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
- vllm/model_executor/layers/quantization/modelopt.py

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 28 additions & 9 deletions
@@ -411,13 +411,23 @@ def cutlass_moe_fp8(
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
 
 
-def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
-                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
-                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
-                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
-                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
-                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
-                    device: torch.device):
+def cutlass_moe_fp4(a: torch.Tensor,
+                    a1_gscale: torch.Tensor,
+                    w1_fp4: torch.Tensor,
+                    w1_blockscale: torch.Tensor,
+                    w1_alphas: torch.Tensor,
+                    a2_gscale: torch.Tensor,
+                    w2_fp4: torch.Tensor,
+                    w2_blockscale: torch.Tensor,
+                    w2_alphas: torch.Tensor,
+                    topk_weights: torch.Tensor,
+                    topk_ids: torch.Tensor,
+                    m: int,
+                    n: int,
+                    k: int,
+                    e: int,
+                    device: torch.device,
+                    apply_router_weight_on_input: bool = False):
     """
     MoE implementation for FP4 Inputs
@@ -480,6 +490,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
 
+    if apply_router_weight_on_input:
+        # TODO: this only works for topK=1, will need to update for topK>1
+        assert num_topk == 1, \
+            "apply_router_weight_on_input is only implemented for topk=1"
+        a.mul_(topk_weights.to(out_dtype))
+
     # problem shapes should have [m, n, k]
     # Note that problem sizes are based on logical number of elements.
     ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
@@ -517,8 +533,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     del int_fp4, int_blockscale
 
     c2 = ops.shuffle_rows(c2, c_map)
-    out = (c2.view(m, num_topk, k) *
-           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
+    if not apply_router_weight_on_input:
+        out = (c2.view(m, num_topk, k) *
+               topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
+    else:
+        out = c2.view(m, num_topk, k).sum(dim=1)
     return out.to(dtype=out_dtype)
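Note on the new apply_router_weight_on_input flag: it moves the router scaling from the expert outputs onto the activations, which the Llama 4 support in this commit needs, and it is currently restricted to top-1 routing. Below is a minimal PyTorch sketch of the two weighting paths, not the CUTLASS kernel; it uses dummy tensors and an identity stand-in for the fused experts so it runs anywhere, and assumes num_topk == 1 as the TODO above requires.

import torch

m, num_topk, k = 4, 1, 8                 # tokens, experts per token, hidden size
out_dtype = torch.half
a = torch.randn(m, k, dtype=out_dtype)   # activations
topk_weights = torch.rand(m, num_topk, dtype=torch.float32)

def experts(x):
    # Identity stand-in for the fused FP4 expert MLPs (illustration only).
    return x.repeat_interleave(num_topk, dim=0)

# apply_router_weight_on_input=True: scale the activations first, then just
# sum the per-expert outputs (only valid for num_topk == 1).
a_scaled = a * topk_weights.to(out_dtype)
out_on_input = experts(a_scaled).view(m, num_topk, k).sum(dim=1)

# apply_router_weight_on_input=False (previous behavior): run the experts on
# the raw activations and weight the outputs before summing.
c2 = experts(a)
out_on_output = (c2.view(m, num_topk, k) *
                 topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)

# With an identity "expert" and top-1 routing the two paths agree.
assert torch.allclose(out_on_input, out_on_output, atol=1e-3)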

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 20 additions & 20 deletions
@@ -295,6 +295,7 @@ def apply(
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for "
                                       "`CompressedTensorsW4A4MoeMethod` yet.")
+        assert activation == "silu", "Only SiLU activation is supported."
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -326,10 +327,6 @@
             global_num_experts=global_num_experts,
             expert_map=expert_map)
 
-        assert activation == "silu", "Only SiLU activation is supported."
-        assert not apply_router_weight_on_input, (
-            "Router weight on input is not "
-            "supported for CompressedTensorsW4A4MoeMethod.")
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "CompressedTensorsW4A4MoeMethod.")
@@ -339,22 +336,25 @@
 
         # Cutlass moe takes in activations in BF16/Half precision
         # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(a=x,
-                               w1_fp4=layer.w13_weight,
-                               w1_blockscale=layer.w13_blockscale_swizzled,
-                               w1_alphas=layer.g1_alphas,
-                               w2_fp4=layer.w2_weight,
-                               w2_blockscale=layer.w2_blockscale_swizzled,
-                               w2_alphas=layer.g2_alphas,
-                               topk_weights=topk_weights,
-                               topk_ids=topk_ids,
-                               m=x.shape[0],
-                               n=layer.w2_weight.shape[2] * 2,
-                               k=x.shape[1],
-                               e=layer.w13_weight.shape[0],
-                               a1_gscale=layer.w13_input_scale_quant,
-                               a2_gscale=layer.w2_input_scale_quant,
-                               device=x.device).to(x.dtype)
+        return cutlass_moe_fp4(
+            a=x,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            m=x.shape[0],
+            n=layer.w2_weight.shape[2] * 2,
+            k=x.shape[1],
+            e=layer.w13_weight.shape[0],
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
+            device=x.device,
+            apply_router_weight_on_input=apply_router_weight_on_input).to(
+                x.dtype)
 
 
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 32 additions & 45 deletions
@@ -673,21 +673,21 @@ def apply(
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
+        assert activation == "silu", "Only SiLU activation is supported."
 
-        if self.use_marlin:
-            topk_weights, topk_ids = FusedMoE.select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
+        if self.use_marlin:
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
@@ -704,44 +704,31 @@ def apply(
             global_num_experts=global_num_experts,
             expert_map=expert_map)
 
-        assert activation == "silu", "Only SiLU activation is supported."
-        assert not apply_router_weight_on_input, (
-            "Router weight on input is not "
-            "supported for ModelOptNvFp4FusedMoE.")
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "ModelOptNvFp4FusedMoE.")
 
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
-
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             cutlass_moe_fp4)
 
         # Cutlass moe takes in activations in BF16/Half precision
         # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(a=x,
-                               w1_fp4=layer.w13_weight,
-                               w1_blockscale=layer.w13_blockscale_swizzled,
-                               w1_alphas=layer.g1_alphas,
-                               w2_fp4=layer.w2_weight,
-                               w2_blockscale=layer.w2_blockscale_swizzled,
-                               w2_alphas=layer.g2_alphas,
-                               topk_weights=topk_weights,
-                               topk_ids=topk_ids,
-                               m=x.shape[0],
-                               n=layer.w2_weight.shape[2] * 2,
-                               k=x.shape[1],
-                               e=layer.w13_weight.shape[0],
-                               a1_gscale=layer.w13_input_scale_quant,
-                               a2_gscale=layer.w2_input_scale_quant,
-                               device=x.device).to(x.dtype)
+        return cutlass_moe_fp4(
+            a=x,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            m=x.shape[0],
+            n=layer.w2_weight.shape[2] * 2,
+            k=x.shape[1],
+            e=layer.w13_weight.shape[0],
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
+            device=x.device,
+            apply_router_weight_on_input=apply_router_weight_on_input).to(
+                x.dtype)
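Both call sites derive the GEMM problem shape for cutlass_moe_fp4 the same way. The sketch below only illustrates that m/n/k/e convention with made-up shapes; the packed-FP4 tensor layouts and sizes are assumptions for illustration, not real checkpoint tensors (FP4 packs two values per byte, hence the "* 2").

import torch

num_tokens, hidden_size = 8, 5120        # made-up model dimensions
num_experts, intermediate = 16, 8192

x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
# Assumed packed-FP4 layouts: two 4-bit values per uint8 along the last dim.
w13_weight = torch.empty(num_experts, 2 * intermediate, hidden_size // 2,
                         dtype=torch.uint8)
w2_weight = torch.empty(num_experts, hidden_size, intermediate // 2,
                        dtype=torch.uint8)

m = x.shape[0]                  # number of tokens
n = w2_weight.shape[2] * 2      # intermediate size, unpacked from FP4
k = x.shape[1]                  # hidden size
e = w13_weight.shape[0]         # number of experts
print(m, n, k, e)               # 8 8192 5120 16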
