Skip to content

Commit 1fce70a

Browse files
authored
[Model] Support common fused moe ops for moe model, such as Qwen3Moe (#709)
vllm-ascend now only support moe for deepseek. We should add common moe support back Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 40bd602 commit 1fce70a

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

vllm_ascend/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch_npu # noqa: F401
2020

2121
import vllm_ascend.ops.activation # noqa
22+
import vllm_ascend.ops.common_fused_moe # noqa
2223
import vllm_ascend.ops.fused_moe # noqa
2324
import vllm_ascend.ops.layernorm # noqa
2425
import vllm_ascend.ops.rotary_embedding # noqa

vllm_ascend/ops/common_fused_moe.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Callable, Optional
19+
20+
import torch
21+
from vllm.model_executor.layers.fused_moe.layer import \
22+
UnquantizedFusedMoEMethod
23+
24+
from vllm_ascend.ops.fused_moe import fused_experts, select_experts
25+
26+
27+
def forward_oot(
28+
self,
29+
layer: torch.nn.Module,
30+
x: torch.Tensor,
31+
use_grouped_topk: bool,
32+
top_k: int,
33+
router_logits: torch.Tensor,
34+
renormalize: bool,
35+
topk_group: Optional[int] = None,
36+
num_expert_group: Optional[int] = None,
37+
custom_routing_function: Optional[Callable] = None,
38+
scoring_func: str = "softmax",
39+
e_score_correction_bias: Optional[torch.Tensor] = None,
40+
global_num_experts: Optional[int] = None,
41+
expert_map: Optional[torch.Tensor] = None,
42+
apply_router_weight_on_input: bool = False,
43+
activation: str = "silu",
44+
) -> torch.Tensor:
45+
topk_weights, topk_ids = select_experts(
46+
hidden_states=x,
47+
router_logits=router_logits,
48+
top_k=top_k,
49+
use_grouped_topk=use_grouped_topk,
50+
renormalize=renormalize,
51+
topk_group=topk_group,
52+
num_expert_group=num_expert_group,
53+
custom_routing_function=custom_routing_function,
54+
scoring_func=scoring_func,
55+
e_score_correction_bias=e_score_correction_bias,
56+
)
57+
58+
return fused_experts(hidden_states=x,
59+
w1=layer.w13_weight,
60+
w2=layer.w2_weight,
61+
topk_weights=topk_weights,
62+
topk_ids=topk_ids,
63+
top_k=top_k,
64+
expert_map=expert_map)
65+
66+
67+
UnquantizedFusedMoEMethod.forward_oot = forward_oot

0 commit comments

Comments
 (0)