Commit 49e9771

weijinqian_v1 committed
add moe_block: AscendSparseMoeBlock
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
1 parent d0bd006 commit 49e9771

3 files changed: +128 -1 lines changed

vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .moe_block import AscendSparseMoeBlock  # noqa: F401

     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -20,6 +21,10 @@ def register_model():
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

+    ModelRegistry.register_model(
+        "Qwen3MoeSparseMoeBlock",
+        "vllm_ascend.models.moe_block:AscendSparseMoeBlock")
+
     if envs.USE_OPTIMIZED_MODEL:
         ModelRegistry.register_model(
             "Qwen2_5_VLForConditionalGeneration",

vllm_ascend/models/moe_block.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+
+from typing import Optional
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              get_tp_group)
+from vllm.distributed.parallel_state import get_dp_group, get_ep_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class AscendSparseMoeBlock(nn.Module):
+
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = \
+            ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None
+    ) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # During the profile run, force the experts to receive balanced
+        # tokens to avoid high memory consumption on a single rank.
+        is_prefill = True
+        if attn_metadata is None:
+            # Profile run.
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            enable_force_load_balance = False
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = attn_metadata.with_prefill_across_dp
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
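
To make the forward pass concrete, here is a framework-free sketch of what the gate-plus-experts pair computes per token: softmax over the router logits, top-k selection, optional renormalization (config.norm_topk_prob), and a weighted sum of expert outputs. This is a dense reference under assumed shapes, not the fused Ascend kernel; it omits TP/EP communication, quantization, and load balancing, and all names are illustrative:

    import torch

    def reference_sparse_moe(hidden_states: torch.Tensor,  # (num_tokens, hidden)
                             gate_weight: torch.Tensor,    # (n_experts, hidden)
                             experts: list,                # callables: (n, hidden) -> (n, hidden)
                             top_k: int,
                             renormalize: bool = True) -> torch.Tensor:
        # router_logits: (num_tokens, n_experts), the role of self.gate above.
        router_logits = hidden_states @ gate_weight.t()
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_probs, topk_ids = probs.topk(top_k, dim=-1)
        if renormalize:  # corresponds to renormalize=config.norm_topk_prob
            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(hidden_states)
        for expert_id, expert in enumerate(experts):
            for slot in range(top_k):
                mask = topk_ids[:, slot] == expert_id
                if mask.any():
                    weight = topk_probs[mask, slot].unsqueeze(-1).to(out.dtype)
                    out[mask] += weight * expert(hidden_states[mask])
        return out

AscendFusedMoE performs the same selection with fused NPU kernels; as the reduce_results=False argument above suggests, it returns partial results and leaves the cross-device reduction to the caller.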

vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 6 additions & 1 deletion
@@ -348,6 +348,10 @@ def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor:
     def routing(self, probs):
         seq_length, bsz = probs.shape[:2]
         probs = probs.view(-1, self.config.num_moe_experts)
+        if self.config.is_fused:
+            score_function = "sigmoid"
+        else:
+            score_function = "softmax"

         scores, routing_map, _, top_indices = topk_softmax_with_capacity(
             probs,
@@ -357,7 +361,8 @@ def routing(self, probs):
             group_topk=self.config.group_topk,
             num_groups=self.config.num_groups,
             expert_bias=self.config.expert_bias,
-            scaling_factor=self.config.scaling_factor
+            scaling_factor=self.config.scaling_factor,
+            score_function=score_function
         )
         self.top_indices = top_indices
         return scores, routing_map
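
The new branch selects between the two common router scoring schemes: softmax normalizes scores across experts, while sigmoid (typical of fused, DeepSeek-style routers) scores each expert independently, which is why the selected top-k weights are usually renormalized afterwards. A toy illustration of the difference, independent of what topk_softmax_with_capacity actually does internally (names are illustrative):

    import torch

    def toy_topk_routing(logits: torch.Tensor, top_k: int, score_function: str):
        if score_function == "sigmoid":
            # Independent per-expert scores; rows do not sum to 1.
            scores = logits.sigmoid()
        else:
            # Mutually exclusive scores across experts.
            scores = logits.softmax(dim=-1)
        topk_scores, topk_ids = scores.topk(top_k, dim=-1)
        if score_function == "sigmoid":
            # Renormalize the selected experts so their weights sum to 1.
            topk_scores = topk_scores / topk_scores.sum(-1, keepdim=True)
        return topk_scores, topk_ids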
