[0.9.1][Feature] MoE alltoallv communication optimization for unquantized RL training scenario & alltoallv support dpo #1547


Open

wants to merge 42 commits into base: v0.9.1-dev
Changes from all commits (42 commits)
7ff288e
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jun 30, 2025
6d7b5b4
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
6a8e1a9
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
4805c5a
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
d68ce07
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
0aff693
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
f6ab19e
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
a94c094
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
91570d8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
e7c0d2d
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
47439e8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
cf3f1c8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
a4126f3
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
807aaf0
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
6f6efc1
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 2, 2025
305a0eb
handle conflict
Jul 8, 2025
5411ed6
add st:qwen3
Jul 8, 2025
3f88769
add st for moe token dispatcher
Jul 8, 2025
854c149
fix bug
harygo22 Jul 8, 2025
d0bd006
add st for moe token dispatcher
Jul 8, 2025
49e9771
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
a9bccf8
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
e31a7df
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
0a22312
[0.9.1][Perf] Optimize the number of rope-related index selections in…
whx-sjtu Jul 8, 2025
ee1dd49
[BUGFIX] FIX mtp accuraccy when temperture is not 0 (#1632)
JC-ut0 Jul 8, 2025
eef1093
add mc2 mask (#1642)
weiguihua2 Jul 8, 2025
eb54e22
[cherry-pick] static EPLB fix bug, add unit test to v0.9.1-dev (#1667)
songshanhu07 Jul 8, 2025
b02ad40
revert
harygo22 Jul 8, 2025
66807e0
fix bug
harygo22 Jul 8, 2025
d24758e
fix a bug
harygo22 Jul 8, 2025
d76c4fb
fix a bug
harygo22 Jul 8, 2025
f883902
ut test
harygo22 Jul 9, 2025
d5656f4
liscens & fix dsk dbo.
harygo22 Jul 9, 2025
df52070
handle conflict
Jul 9, 2025
adf3f74
handle code clean
Jul 9, 2025
5956ef0
handle code clean
Jul 9, 2025
af85566
handle code clean
Jul 9, 2025
d4ad734
handle code clean
Jul 9, 2025
847d52d
Merge branch 'v0.9.1-dev' into v0.9.1-dev
harygo22 Jul 10, 2025
3b7269a
fix comment
harygo22 Jul 10, 2025
a8b3e15
fix init
harygo22 Jul 10, 2025
d290b7d
remove files & move sparsemoeblock to ops
harygo22 Jul 10, 2025
75 changes: 75 additions & 0 deletions tests/multicard/test_qwen3_moe.py
@@ -0,0 +1,75 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/multicard/test_data_parallel.py`.
"""

import os
import subprocess
import sys
from unittest.mock import patch

import pytest

MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@patch.dict(
os.environ, {
"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3",
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1",
"VLLM_ASCEND_ENABLE_DBO": "1"
})
def test_qwen3_moe_inference(model, max_tokens):
script = "examples/offline_data_parallel.py"

env = os.environ.copy()

cmd = [
sys.executable,
script,
"--model",
model,
"--dp-size",
"2",
"--tp-size",
"2",
"--node-size",
"1",
"--node-rank",
"0",
"--trust-remote-code",
"--enforce-eager",
]

print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600)
output = proc.stdout.decode()

print(output)

assert "DP rank 0 needs to process" in output
assert "DP rank 1 needs to process" in output
assert "Generated text:" in output
assert proc.returncode == 0
127 changes: 127 additions & 0 deletions tests/ut/test_distributed_tensor_parallel.py
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import importlib
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_ascend.distributed.tensor_parallel import (
_gather_along_first_dim, _gather_along_last_dim,
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
all_to_all_hp2sp, all_to_all_sp2hp)


@pytest.fixture
def test_tensor():
return torch.randn(8, 16)


@pytest.fixture
def test_tensor_last_dim():
return torch.randn(8, 16, 32)


@pytest.fixture
def mock_group():
return MagicMock()


@pytest.fixture(autouse=True)
def mock_dist():
with patch("torch.distributed") as mock:
mock.get_world_size.return_value = 4
mock.get_rank.return_value = 0
yield mock


class TestDistributedCommunication:

@pytest.mark.parametrize("world_size", [1, 4])
def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
world_size):
"""test _gather_along_first_dim"""
mock_dist.get_world_size.return_value = world_size

result = _gather_along_first_dim(test_tensor, mock_group)

if world_size == 1:
assert torch.equal(result, test_tensor)
else:
assert result.shape == (32, 16) # 8*4=32

def test_gather_along_first_dim_unequal_split(self, test_tensor,
mock_group):
"""test unequal split"""
output_split_sizes = [5, 10, 15, 2]
result = _gather_along_first_dim(test_tensor, mock_group,
output_split_sizes)
assert result.shape == (32, 16) # 5+10+15+2=32

@pytest.mark.parametrize("world_size", [1, 4])
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
mock_dist, world_size):
"""test _gather_along_last_dim"""
mock_dist.get_world_size.return_value = world_size

result = _gather_along_last_dim(test_tensor_last_dim, mock_group)

if world_size == 1:
assert torch.equal(result, test_tensor_last_dim)
else:
            assert result.shape == (8, 16, 32 * world_size)  # 32*4=128

@pytest.mark.parametrize("input_shape,expected_shape", [
((32, 16), (8, 16)),
((40, 10), (10, 10)),
])
def test_reduce_scatter_along_first_dim(self, mock_group, input_shape,
expected_shape):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
assert result.shape == expected_shape

def test_reduce_scatter_along_last_dim(self, mock_group):
input_tensor = torch.randn(8, 16, 32)
result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
assert result.shape == (8, 16, 8) # 32/4=8

@pytest.mark.parametrize("func,input_shape,expected_shape", [
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
(8, 16, 128)),
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
(8, 16, 8)),
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
])
def test_wrapper_functions(self, mock_group, func, input_shape,
expected_shape):
"""test wrapper funcs"""
mod = importlib.import_module(
'vllm_ascend.distributed.tensor_parallel')
        test_func = getattr(mod, func)
input_tensor = torch.randn(*input_shape)
result = test_func(input_tensor, mock_group)
assert result.shape == expected_shape

@pytest.mark.parametrize(
"input_shape,output_shape",
[
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
])
def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
input_tensor = torch.randn(*input_shape)
result = all_to_all_sp2hp(input_tensor, mock_group)
assert result.shape == output_shape

@pytest.mark.parametrize(
"input_shape,output_shape",
[
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
])
def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
input_tensor = torch.randn(*input_shape)
result = all_to_all_hp2sp(input_tensor, mock_group)
assert result.shape == output_shape
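As context for the shape assertions in the two all-to-all tests above: all_to_all_sp2hp trades a sequence-parallel layout ([num_tokens/TP, H] per rank) for a hidden-parallel one ([num_tokens, H/TP] per rank), and all_to_all_hp2sp is its inverse. The following minimal single-process sketch (not part of the diff) simulates that layout exchange locally instead of calling the real collective, so it is illustrative only.

import torch

tp = 4
num_tokens, hidden = 32, 16
full = torch.randn(num_tokens, hidden)
# SP layout: each of the tp ranks holds a token shard with the full hidden dim, [8, 16].
sp_shards = list(full.chunk(tp, dim=0))

# all_to_all_sp2hp result, simulated: every rank ends up with all tokens but only
# its own hidden slice, [32, 4] -- the shape pair asserted in the test above.
hp_shards = [
    torch.cat([shard.chunk(tp, dim=-1)[r] for shard in sp_shards], dim=0)
    for r in range(tp)
]
assert hp_shards[0].shape == (num_tokens, hidden // tp)

# all_to_all_hp2sp, simulated for rank 0: the round trip recovers the original SP shard.
recovered = torch.cat([hp_shards[r].chunk(tp, dim=0)[0] for r in range(tp)], dim=-1)
assert torch.equal(recovered, sp_shards[0])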
62 changes: 62 additions & 0 deletions tests/ut/test_token_dispatcher.py
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.

import pytest
import torch
from pytest_mock import MockerFixture

from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig)
from vllm_ascend.utils import adapt_patch # noqa E402

import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa



adapt_patch(True)


class TestMoEAlltoAllSeqOverLapDispatcher:

@pytest.fixture
def config(self):
config = MoeDispatcherConfig()
config.set_num_local_experts(2)
config.set_num_moe_experts(4)
config.set_moe_pad_expert_input_to_capacity(False)
config.set_moe_expert_capacity_factor(None)
config.set_moe_router_topk(2)
config.set_moe_grouped_gemm(False)
config.set_group_topk(0)
config.set_num_groups(1)
config.set_is_fused(False)
return config.build()

def mock_ep_group(self, mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
return mock_group

@pytest.fixture
def dispatcher(self, config, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
return_value=self.mock_ep_group(mocker))
return MoEAlltoAllSeqOverLapDispatcher(config)

def test_initialization(self, dispatcher, config):
assert dispatcher.num_local_experts == config.num_local_experts
assert dispatcher.num_experts == config.num_moe_experts
assert dispatcher.local_expert_indices == [0, 1]
assert dispatcher.ep_rank == 0
assert dispatcher.ep_size == 2
assert dispatcher.overlap_stream is not None

def test_routing(self, dispatcher):
probs = torch.randn(4, 4) # 4 tokens, 4 experts
scores, routing_map = dispatcher.routing(probs)
assert scores.shape == (4, 4) # topk=2
assert routing_map.shape == (4, 4)
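The test_routing case above only checks shapes; as a minimal sketch of the general top-k routing technique it exercises (an illustration, not the dispatcher's actual routing implementation):

import torch


def topk_routing(logits: torch.Tensor, topk: int = 2):
    """Pick top-k experts per token; return per-expert scores and a boolean routing map."""
    probs = logits.softmax(dim=-1)
    top_vals, top_idx = probs.topk(topk, dim=-1)
    routing_map = torch.zeros_like(probs, dtype=torch.bool).scatter_(1, top_idx, True)
    scores = torch.zeros_like(probs).scatter_(1, top_idx, top_vals)
    return scores, routing_map


scores, routing_map = topk_routing(torch.randn(4, 4))  # 4 tokens, 4 experts
assert scores.shape == (4, 4) and routing_map.shape == (4, 4)
assert routing_map.sum(dim=-1).eq(2).all()  # exactly topk experts per token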
7 changes: 7 additions & 0 deletions vllm_ascend/ascend_forward_context.py
@@ -11,19 +11,26 @@

import vllm_ascend.envs as envs

import vllm_ascend.envs as envs_ascend


class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
MC2_PREFILL = 3
All2AllSeq = 4


# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool):
enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
if ep_size == 1:
return FusedMoEState.AllGather
elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
# MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
return FusedMoEState.All2AllSeq if (
ep_size < 16 or with_prefill) else FusedMoEState.MC2
Collaborator comment:

Why is there this restriction ep_size < 16?

@harygo22 (harygo22), Jul 10, 2025:

MC2 dispatch/combine is still faster than alltoall_seq in the decoding stage, so when ep_size >= 16, MC2 is used for better performance.

elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
return FusedMoEState.MC2_PREFILL
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
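A standalone sketch of how the new All2AllSeq branch in get_fused_moe_state resolves for a few (ep_size, with_prefill) combinations. The enum values and the shown branches come from the hunk above; the helper name and the final fallback (collapsed in the diff) are assumptions, not the actual vllm_ascend code.

from enum import Enum


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    MC2_PREFILL = 3
    All2AllSeq = 4


def select_fused_moe_state(ep_size: int, with_prefill: bool,
                           enable_all2all_seq: bool,
                           enable_chunk_mc2: bool) -> FusedMoEState:
    if ep_size == 1:
        return FusedMoEState.AllGather
    if enable_all2all_seq:
        # MC2 dispatch/combine still wins over alltoall_seq for decode at large EP sizes.
        return FusedMoEState.All2AllSeq if (ep_size < 16
                                            or with_prefill) else FusedMoEState.MC2
    if ep_size >= 16 and with_prefill and enable_chunk_mc2:
        return FusedMoEState.MC2_PREFILL
    # Assumed fallback, based on the NOTE above (mc2 needs ep_size >= 16).
    return FusedMoEState.MC2 if ep_size >= 16 else FusedMoEState.All2All


# With VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ=1, as in the new test above:
assert select_fused_moe_state(8, True, True, False) is FusedMoEState.All2AllSeq
assert select_fused_moe_state(32, False, True, False) is FusedMoEState.MC2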
13 changes: 13 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -32,6 +32,7 @@

from vllm_ascend.attention.utils import \
AscendCommonAttentionMetadata as CommonAttentionMetadata
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import get_graph_params

@@ -140,6 +141,18 @@ class AscendMetadata:

enable_dbo_across_dp: bool = False

def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMetadata"]:
"""Split metadata for multi-stream with AscendMetadata"""
from vllm_ascend.multistream.ms_split import model_input_split_v1_attn
return model_input_split_v1_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMetadata,
)


class AscendAttentionMetadataBuilder:

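Finally, a deliberately simplified, hypothetical illustration of what split_metadata_for_multistream is for: under dual batch overlap (DBO), one batch's per-request attention metadata is cut into two micro-batches so their computation and MoE communication can overlap on separate streams. The real split lives in vllm_ascend.multistream.ms_split.model_input_split_v1_attn; the toy dataclass and split rule below are assumptions for illustration only.

from dataclasses import dataclass

import torch


@dataclass
class ToyAttentionMetadata:
    seq_lens: torch.Tensor  # per-request sequence lengths
    num_actual_tokens: int


def toy_split(meta: ToyAttentionMetadata) -> list:
    """Cut the requests in half so two streams each handle roughly half the tokens."""
    n = meta.seq_lens.numel()
    if n < 2:
        return [meta]
    mid = n // 2
    first, second = meta.seq_lens[:mid], meta.seq_lens[mid:]
    return [ToyAttentionMetadata(first, int(first.sum())),
            ToyAttentionMetadata(second, int(second.sum()))]


micro = toy_split(ToyAttentionMetadata(torch.tensor([4, 7, 5, 9]), 25))
assert [m.num_actual_tokens for m in micro] == [11, 14]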