Skip to content

add qwen3-moe optimization #1441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tests/e2e/multicard/test_qwen3_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
Compare the outputs of vLLM with and without aclgraph.

Run `pytest tests/multicard/test_data_parallel.py`.
"""

import os
import subprocess
import sys
from unittest.mock import patch

import pytest

MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
def test_qwen3_moe_inference(model, max_tokens):
script = "examples/offline_data_parallel.py"

env = os.environ.copy()

cmd = [
sys.executable,
script,
"--model",
model,
"--dp-size",
"2",
"--tp-size",
"2",
"--node-size",
"1",
"--node-rank",
"0",
"--trust-remote-code",
"--enforce-eager",
]

print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600)
output = proc.stdout.decode()

print(output)

assert "DP rank 0 needs to process" in output
assert "DP rank 1 needs to process" in output
assert "Generated text:" in output
assert proc.returncode == 0
99 changes: 99 additions & 0 deletions vllm_ascend/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,24 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

from typing import Optional

import torch
import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
Expand All @@ -33,3 +49,86 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}


class AscendQwen3MoeSparseMoeBlock(nn.Module):
Copy link
Collaborator

@Yikun Yikun Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k: int

def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe

self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")

self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")

self.top_k = config.num_experts_per_tok

self.dp_size = get_dp_group().world_size

self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()

self.params_dtype = torch.get_default_dtype()

def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
if attn_metadata is None:
# for profile run
is_prefill = True
enable_force_load_balance = True
else:
# is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = attn_metadata.with_prefill_across_dp

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None)

return hidden_states


vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
16 changes: 9 additions & 7 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,13 @@
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None
shared_experts: Optional[Any] = None,
global_batch_size: int = 256,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
global_bs = 0

ep_group = get_ep_group().device_group
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
global_bs = global_batch_size * all_to_all_group_size

Check warning on line 127 in vllm_ascend/ops/fused_moe.py

View check run for this annotation

Codecov / codecov/patch

vllm_ascend/ops/fused_moe.py#L125-L127

Added lines #L125 - L127 were not covered by tests
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
Expand All @@ -132,11 +136,8 @@
}

rank = torch.distributed.get_rank()

quant_mode = 0
ep_group = get_ep_group().device_group
local_rank = torch.distributed.get_rank(group=ep_group)
all_to_all_group_size = torch.distributed.get_world_size(ep_group)

tp_size = get_etp_group().world_size
tp_rank = rank % tp_size
Expand Down Expand Up @@ -204,7 +205,7 @@
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"global_bs": global_bs,
}
tp_recv_counts = output[5]
stage3_kwargs = {
Expand Down Expand Up @@ -1037,7 +1038,8 @@
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts)
shared_experts=shared_experts,
global_batch_size=self.global_batch_size)
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
Expand Down
Loading