Commit e564470

cxcxflyingchenxuevian authored
[Attention][Kernel] MoE support for llama4 and mllama4 (#740)
### What this PR does / why we need it?

MoE support for llama4 and mllama4 in vllm-ascend.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Start the server:

```bash
python -m vllm.entrypoints.openai.api_server \
  --model /data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct \
  --max-num-seqs=256 \
  --max-model-len=8192 \
  --tensor-parallel-size=8 \
  --block-size=128 \
  --dtype bfloat16 \
  --host=0.0.0.0 \
  --port=8000 \
  --gpu-memory-utilization=0.9 \
  --trust-remote-code
```

Run the client:

```bash
python online_server.py \
  --model-path /data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct \
  --image-path /data/nfs/w60040464/cherry_blossom.jpg \
  --docker-ip 7.242.108.253 \
  --served-port 8000 \
  --text "what is the content of this image?"
```

Result:

```
{'id': 'chatcmpl-2b709a5d2e1a4017991ec4ba8248686a', 'object': 'chat.completion', 'created': 1747056823, 'model': '/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'reasoning_content': None, 'content': 'The image depicts a tower, likely Tokyo Skytree, framed by branches of a cherry blossom tree. The tower is white and has a distinctive shape, with a large sphere at the top and a long, thin spire extending from it. The branches of the cherry blossom tree are in the foreground, with pink flowers blooming on them. The background is a clear blue sky.\n\n**Key Features:**\n\n* **Tower:** White, spherical shape at the top, long thin spire\n', 'tool_calls': []}, 'logprobs': None, 'finish_reason': 'length', 'stop_reason': None}], 'usage': {'prompt_tokens': 2340, 'total_tokens': 2440, 'completion_tokens': 100, 'prompt_tokens_details': None}, 'prompt_logprobs': None}
```

Signed-off-by: chenxu <chenxu68@huawei.com>
Co-authored-by: chenxu <chenxu68@huawei.com>
Co-authored-by: evian <eviantai@u.nus.edu>
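For reference, a minimal sketch of an equivalent request sent directly to the server's OpenAI-compatible `/v1/chat/completions` endpoint instead of through `online_server.py`; the host, port, and `IMAGE_URL` below are placeholder assumptions, not values taken from this PR.

```python
# Sketch only: query the vLLM OpenAI-compatible server started above.
# Host/port and IMAGE_URL are hypothetical placeholders.
import requests

IMAGE_URL = "https://example.com/cherry_blossom.jpg"  # hypothetical image location

response = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": "/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": IMAGE_URL}},
                {"type": "text", "text": "what is the content of this image?"},
            ],
        }],
        "max_tokens": 100,
    },
    timeout=300,
)
# Print the generated answer from the first choice.
print(response.json()["choices"][0]["message"]["content"])
```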
1 parent 217211d commit e564470

File tree

4 files changed: +34 −12 lines

- vllm_ascend/attention/attention.py
- vllm_ascend/attention/attention_v1.py
- vllm_ascend/ops/common_fused_moe.py
- vllm_ascend/ops/fused_moe.py

vllm_ascend/attention/attention.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -708,6 +708,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
```

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -174,6 +174,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
```

vllm_ascend/ops/common_fused_moe.py

Lines changed: 9 additions & 7 deletions

```diff
@@ -55,13 +55,15 @@ def forward_oot(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    return fused_experts(hidden_states=x,
-                         w1=layer.w13_weight,
-                         w2=layer.w2_weight,
-                         topk_weights=topk_weights,
-                         topk_ids=topk_ids,
-                         top_k=top_k,
-                         expert_map=expert_map)
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        top_k=top_k,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input)
 
 
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
```
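The `apply_router_weight_on_input` argument threaded through here is consumed by `fused_experts` in `vllm_ascend/ops/fused_moe.py` (next diff): when it is set, the top-1 routing weight is folded into the token activations before expert dispatch, and unit scales are used at the finalize step. A standalone sketch of that pre-scaling, for illustration only (the helper name is hypothetical):

```python
import torch


def scale_input_by_router_weight(hidden_states: torch.Tensor,
                                 topk_weights: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the pre-scaling added in fused_experts:
    # only top-1 routing is supported when the weight is applied on the input.
    assert topk_weights.dim() == 2, "`topk_weights` should be (num_tokens, topk)"
    assert topk_weights.shape[1] == 1, "only topk=1 is supported here"
    return hidden_states * topk_weights.to(hidden_states.dtype)


# Example: 4 tokens, hidden size 8, one routing weight per token.
x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.rand(4, 1, dtype=torch.float32)
scaled = scale_input_by_router_weight(x, w)
```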

vllm_ascend/ops/fused_moe.py

Lines changed: 23 additions & 5 deletions

```diff
@@ -153,6 +153,7 @@ def fused_experts(
     topk_ids: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -191,6 +192,15 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     # ], "Only float32, float16, and bfloat16 are supported"
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+
     if expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
@@ -292,14 +302,16 @@ def fused_experts(
                                  torch.zeros_like(weighted_down_out)).to(dtype)
         final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
+        scales = torch.ones_like(
+            topk_weights) if apply_router_weight_on_input else topk_weights
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             down_out_list,
             skip1=None,
             skip2=None,
             bias=None,
-            scales=topk_weights,
+            scales=scales,
             expanded_src_to_dst_row=expanded_row_idx,
             export_for_source_row=topk_ids,
         )
@@ -366,9 +378,6 @@ def select_experts(
     Raises:
         ValueError: If an unsupported scoring function is provided.
     """
-    if custom_routing_function is not None:
-        raise NotImplementedError(
-            "Custom routing function is not supported now")
 
     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
@@ -405,9 +414,18 @@ def select_experts(
                                               k=top_k,
                                               dim=-1,
                                               sorted=False)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
         topk_weights = topk_weights.to(hidden_states.dtype)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize)
+        # Required by npu_moe_init_routing
+        topk_ids = topk_ids.to(torch.int32)
+        return topk_weights, topk_ids
 
     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
```
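With the `NotImplementedError` removed, `select_experts` now delegates to `custom_routing_function` when one is supplied, using the keyword call shown above. Below is an illustrative routing function matching that contract; it is a plain softmax top-k used only to show the expected signature and return values, not the actual Llama-4 router.

```python
import torch


def example_custom_routing(hidden_states: torch.Tensor,
                           gating_output: torch.Tensor,
                           topk: int,
                           renormalize: bool):
    # Illustrative only: matches the keyword signature used by select_experts
    # above (hidden_states, gating_output, topk, renormalize) and returns
    # (topk_weights, topk_ids).
    weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # select_experts casts topk_ids to int32 afterwards, as required by
    # npu_moe_init_routing.
    return topk_weights.to(hidden_states.dtype), topk_ids
```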
