
Commit 5b926ae

Author: yangcheng (AJ)
add qwen3-moe optimization
Signed-off-by: yangcheng (AJ) <yangcheng104@huawei.com>
1 parent 53c2d58 commit 5b926ae

File tree: 2 files changed, +111 -0 lines changed


tests/e2e/singlecard/test_offline_inference.py (1 addition, 0 deletions)

@@ -35,6 +35,7 @@
 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
     "Qwen/Qwen3-0.6B-Base",
+    "Qwen/Qwen3-30B-A3B",
 ]
 MULTIMODALITY_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
 
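The new entry folds the Qwen3 MoE checkpoint "Qwen/Qwen3-30B-A3B" into the same offline-inference sweep as the dense Qwen models. For orientation, a minimal sketch of what one iteration of such a test exercises, assuming the suite drives vllm's offline LLM API; the prompt and sampling values are illustrative stand-ins, not taken from the test file:

from vllm import LLM, SamplingParams

# Hypothetical single-model run mirroring the e2e loop; the model id is the
# entry added above, everything else here is an assumed placeholder.
llm = LLM(model="Qwen/Qwen3-30B-A3B")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)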

vllm_ascend/models/qwen3_moe.py (110 additions, 0 deletions)
@@ -16,7 +16,31 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch_npu
+import vllm
+import vllm.envs as envs
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              get_tp_group)
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
+from transformers import PretrainedConfig
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -33,3 +57,89 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
+
+
+class AscendQwen3MoeSparseMoeBlock(nn.Module):
+
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # When a profile run occurs, force the experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        if attn_metadata is None:
+            # for profile run
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            # is_prefill = attn_metadata.num_prefills > 0
+            is_prefill = False
+            enable_force_load_balance = False
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = attn_metadata.with_prefill_across_dp
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
+
+
+vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
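The assignment on the last line is the integration mechanism for the whole module: importing vllm_ascend.models.qwen3_moe rebinds the Qwen3MoeSparseMoeBlock attribute on vllm's upstream module, so any Qwen3MoeForCausalLM constructed afterwards resolves that name to AscendQwen3MoeSparseMoeBlock and routes its MoE computation through AscendFusedMoE. A minimal sketch of the effect, assuming vllm and vllm-ascend are both importable; the assert is purely illustrative:

import vllm.model_executor.models.qwen3_moe as upstream_qwen3_moe

# Importing the vllm-ascend module applies the class override as a side effect.
import vllm_ascend.models.qwen3_moe  # noqa: F401

# The upstream module attribute now points at the Ascend implementation, so a
# model instantiated after this import picks up the NPU-optimized MoE block.
assert upstream_qwen3_moe.Qwen3MoeSparseMoeBlock.__name__ == \
    "AscendQwen3MoeSparseMoeBlock"

Because the override is applied at import time, it only affects models instantiated after the vllm_ascend module is loaded; in practice the vllm-ascend platform plugin is expected to load its model modules during engine startup, before any model weights are constructed.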
