diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py
new file mode 100644
index 0000000000..e05266ee05
--- /dev/null
+++ b/tests/e2e/multicard/test_qwen3_moe.py
@@ -0,0 +1,72 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+"""
+Run Qwen3-MoE offline inference with data parallelism in a subprocess.
+
+Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
+"""
+
+import os
+import subprocess
+import sys
+from unittest.mock import patch
+
+import pytest
+
+MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
+def test_qwen3_moe_inference(model, max_tokens):
+    script = "examples/offline_data_parallel.py"
+
+    env = os.environ.copy()
+
+    cmd = [
+        sys.executable,
+        script,
+        "--model",
+        model,
+        "--dp-size",
+        "2",
+        "--tp-size",
+        "2",
+        "--node-size",
+        "1",
+        "--node-rank",
+        "0",
+        "--trust-remote-code",
+        "--enforce-eager",
+    ]
+
+    print(f"Running subprocess: {' '.join(cmd)}")
+    proc = subprocess.run(cmd,
+                          env=env,
+                          stdout=subprocess.PIPE,
+                          stderr=subprocess.STDOUT,
+                          timeout=600)
+    output = proc.stdout.decode()
+
+    print(output)
+
+    assert "DP rank 0 needs to process" in output
+    assert "DP rank 1 needs to process" in output
+    assert "Generated text:" in output
+    assert proc.returncode == 0
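The new test drives `examples/offline_data_parallel.py` in a subprocess and greps the merged stdout/stderr for per-DP-rank markers and a generation line. If more multicard models get the same end-to-end treatment, the run-and-grep pattern could be pulled into a small helper; the sketch below only illustrates that idea (the `run_and_expect` name and its placement are hypothetical, not part of this change):

```python
import subprocess
import sys


def run_and_expect(cmd, markers, timeout=600):
    """Run `cmd`, fold stderr into stdout, and require every marker in the output."""
    proc = subprocess.run(cmd,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          timeout=timeout)
    output = proc.stdout.decode()
    for marker in markers:
        assert marker in output, f"missing expected log line: {marker!r}"
    assert proc.returncode == 0
    return output


# The same checks the test performs, expressed through the helper:
# run_and_expect(
#     [sys.executable, "examples/offline_data_parallel.py",
#      "--model", "vllm-ascend/Qwen3-30B-A3B-Puring",
#      "--dp-size", "2", "--tp-size", "2",
#      "--node-size", "1", "--node-rank", "0",
#      "--trust-remote-code", "--enforce-eager"],
#     ["DP rank 0 needs to process", "DP rank 1 needs to process",
#      "Generated text:"])
```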
diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py
index 8ff1b52a7a..3ed9663206 100644
--- a/vllm_ascend/models/qwen3_moe.py
+++ b/vllm_ascend/models/qwen3_moe.py
@@ -16,8 +16,24 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
+from typing import Optional
+
+import torch
+import vllm
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -33,3 +49,86 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
+
+
+class AscendQwen3MoeSparseMoeBlock(nn.Module):
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # During the profile run, force the experts to receive a balanced number
+        # of tokens to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        if attn_metadata is None:
+            # for profile run
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            # is_prefill = attn_metadata.num_prefills > 0
+            # default to decode when the metadata carries no DP prefill flag
+            is_prefill = False
+            enable_force_load_balance = False
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = attn_metadata.with_prefill_across_dp
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None)
+
+        return hidden_states
+
+
+vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
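The module ends by reassigning `Qwen3MoeSparseMoeBlock` on the upstream `vllm.model_executor.models.qwen3_moe` module, so decoder layers built after this import resolve the Ascend block instead. This works because the upstream layer looks the class up through the module's globals at construction time. Below is a minimal, self-contained sketch of that patching pattern with made-up toy classes (nothing here is from either codebase):

```python
import types

# Stand-in for an upstream module that owns the class being swapped out.
upstream = types.ModuleType("upstream_models")


class UpstreamBlock:

    def forward(self) -> str:
        return "upstream"


class PatchedBlock:

    def forward(self) -> str:
        return "patched"


upstream.Block = UpstreamBlock


def build_layer(module):
    # Mirrors how a decoder layer resolves the block class at construction
    # time: through an attribute lookup on the owning module.
    return module.Block()


# Reassign the attribute before any layer is built, as the qwen3_moe patch does.
upstream.Block = PatchedBlock
assert build_layer(upstream).forward() == "patched"
```

The only ordering requirement is that the reassignment runs before the model is constructed; objects instantiated earlier keep their original type.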
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index f8a4f5eb64..2635115658 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -118,9 +118,13 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
-    shared_experts: Optional[Any] = None
+    shared_experts: Optional[Any] = None,
+    global_batch_size: int = 256,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+
+    ep_group = get_ep_group().device_group
+    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
+    global_bs = global_batch_size * all_to_all_group_size
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -132,11 +136,8 @@
     }
 
     rank = torch.distributed.get_rank()
-
     quant_mode = 0
-    ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
 
     tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
@@ -204,7 +205,7 @@
         "expert_shard_type": 0,
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
-        "global_bs": 0,
+        "global_bs": global_bs,
     }
     tp_recv_counts = output[5]
     stage3_kwargs = {
@@ -1037,7 +1038,8 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                global_batch_size=self.global_batch_size)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
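On the mc2 path, `global_bs` was previously hard-coded to 0; it is now derived from the new `global_batch_size` keyword and the size of the all-to-all (expert-parallel) process group, and the computed value is forwarded to the dispatch/combine kernels. A small worked example of the arithmetic follows; the 256 default comes from the new keyword, while the group sizes are made-up illustrations:

```python
def compute_global_bs(global_batch_size: int, all_to_all_group_size: int) -> int:
    # Mirrors the new computation in fused_experts_with_mc2:
    #   global_bs = global_batch_size * all_to_all_group_size
    return global_batch_size * all_to_all_group_size


# Default keyword value (256) across a 4-rank all-to-all group:
assert compute_global_bs(256, 4) == 1024
# A single-rank group degenerates to the per-rank batch size:
assert compute_global_bs(256, 1) == 256
```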