
Commit c58accc

ApsarasX, yiz-liu, and Yikun authored
[Bugfix] Support Qwen3-MOE on aclgraph mode (#1381)
### What this PR does / why we need it?
Fix the shape of the `npu_moe_init_routing` input parameters to support aclgraph mode on Qwen3-MoE.

In addition to this PR, resolving the `gatherv3` error might be necessary; see related PRs #1297 and #1446.

Thanks to @yiz-liu for providing the idea.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tested on Qwen3-30B-A3B

Closes: #1368

---------

Signed-off-by: ApsarasX <apsarax@outlook.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 14373f6 commit c58accc
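
In outline, the patch makes the unquantized fused-MoE method remember at init time whether aclgraph is active (piecewise compilation without `enforce_eager`) and, in that case, pins the `active_num` argument of `npu_moe_init_routing` to the scheduler's `max_num_batched_tokens`, so the op always receives a fixed input shape during graph capture. Below is a minimal sketch of that decision logic, assuming the vLLM config objects used in the diffs; `resolve_active_num` is a made-up helper name for illustration, not part of the patch.

```python
from vllm.config import CompilationLevel, get_current_vllm_config


def resolve_active_num(num_tokens: int) -> int:
    """Hypothetical helper: pick the `active_num` for npu_moe_init_routing.

    Under aclgraph (piecewise compilation, eager mode not forced) the value is
    pinned to the scheduler's max_num_batched_tokens so every captured graph
    sees the same input shape; otherwise it tracks the live batch size, as it
    did before this patch.
    """
    cfg = get_current_vllm_config()
    use_aclgraph = (cfg.compilation_config.level == CompilationLevel.PIECEWISE
                    and not cfg.model_config.enforce_eager)
    return cfg.scheduler_config.max_num_batched_tokens if use_aclgraph else num_tokens
```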

File tree

3 files changed, +21 -3 lines changed

- tests/e2e/singlecard/test_aclgraph.py
- vllm_ascend/ops/common_fused_moe.py
- vllm_ascend/ops/fused_moe.py

tests/e2e/singlecard/test_aclgraph.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
 from tests.conftest import VllmRunner
 from tests.model_utils import check_outputs_equal
 
-MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "vllm-ascend/Qwen3-30B-A3B-Puring"]
 
 
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",

vllm_ascend/ops/common_fused_moe.py
Lines changed: 17 additions & 1 deletion

@@ -18,13 +18,23 @@
 from typing import Callable, Optional
 
 import torch
+from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
 from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
                                        select_experts)
 from vllm_ascend.utils import is_310p
 
+original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
+
+
+def unquantized_fused_moe_init_func(self, *args, **kwargs):
+    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
+    vllm_config = get_current_vllm_config()
+    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+    self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+
 
 def forward_oot(
         self,
@@ -71,6 +81,10 @@ def forward_oot(
             expert_map=expert_map,
             apply_router_weight_on_input=apply_router_weight_on_input)
 
+    # If use aclgraph, we need to set max_num_tokens to make
+    # the input shape of `npu_moe_init_routing` fixed
+    max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+
     return fused_experts(
         hidden_states=x,
         w1=layer.w13_weight,
@@ -79,7 +93,9 @@ def forward_oot(
         topk_ids=topk_ids,
         top_k=top_k,
         expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input)
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        max_num_tokens=max_num_tokens)
 
 
+UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.forward_oot = forward_oot

vllm_ascend/ops/fused_moe.py
Lines changed: 3 additions & 1 deletion

@@ -655,6 +655,7 @@ def fused_experts(
     top_k: int,
     expert_map: torch.Tensor = None,
     apply_router_weight_on_input: bool = False,
+    max_num_tokens: Optional[int] = None,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -748,11 +749,12 @@ def fused_experts(
                              dtype=torch.int32,
                              device=device).view(top_k, -1).permute(
                                  1, 0).contiguous())
+    active_num = max_num_tokens if max_num_tokens is not None else num_tokens
     sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
         hidden_states,
         row_idx=row_idx,
         expert_idx=topk_ids,
-        active_num=num_tokens)
+        active_num=active_num)
 
     expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
         expanded_expert_idx, num_experts)
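
Beyond the updated e2e test, a rough manual check of the same path is to compare greedy outputs with aclgraph enabled (the default when `enforce_eager` is left off) against eager mode, similar in spirit to `test_aclgraph.py`. The sketch below is illustrative only: the model id, prompt, and token budget are assumptions, and running Qwen3-30B-A3B this way requires Ascend hardware with vllm-ascend installed.

```python
from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is"]
PARAMS = SamplingParams(temperature=0.0, max_tokens=32)  # greedy decoding


def run(enforce_eager: bool) -> str:
    # In practice run each configuration in its own process; shown inline for brevity.
    llm = LLM(model="Qwen/Qwen3-30B-A3B", enforce_eager=enforce_eager)
    return llm.generate(PROMPTS, PARAMS)[0].outputs[0].text


if __name__ == "__main__":
    print("aclgraph:", run(enforce_eager=False))
    print("eager   :", run(enforce_eager=True))
```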
