Commit 63944db

Authored by weijinqian0, weijinqian_v1, harygo22, whx-sjtu, and JC-ut0
[0.9.1][Feature] MoE alltoallv communication optimization for unquantized RL training scenario & alltoallv support for DBO (#1547)
[Feature] MoE alltoallv communication optimization for unquantized RL training scenario & alltoallv support for DBO

## Introduction

This PR introduces two key optimizations for MoE model performance:

1. **Efficient Token Dispatcher**:
   - Implements an optimized `alltoallv_seq` token dispatcher (adopted from NVIDIA Megatron and Ascend MindSpeed)
   - Significantly more efficient than the current alltoall implementation when using token_permute/unpermute fusion
   - Enable with: `VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ=1`

2. **DBO Support for alltoallv_seq**:
   - Builds upon the `alltoallv_seq` dispatcher to support DBO (Dual Batch Overlap)
   - Enables overlapping of alltoallv communication during the prefill stage
   - Enable with both:
     - `VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ=1`
     - `VLLM_ASCEND_ENABLE_DBO=1`

## Performance Improvements

Testing on Qwen3-30B-A3B shows a **nearly 2x throughput improvement** compared to the original alltoall implementation.

![](http://image.huawei.com/tiny-lts/v1/images/mdstorm/903e562dc8f8304ccab06a95774a25e8_1478x705.png)

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: songshanhu07 <1763685535@qq.com>
Signed-off-by: duyangkai <duyangkai@huawei.com>

Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: yangkai <duyangkai@huawei.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: songshanhu07 <1763685535@qq.com>
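For context, a minimal sketch of how the two switches described above could be enabled for an offline run on Ascend. This is not taken from the PR: the model name, parallel sizes, and sampling settings are placeholders, and `enable_expert_parallel` is assumed to be available in the installed vLLM build.

```python
import os

# Both flags are read by vllm-ascend at engine initialization,
# so set them before the engine is created.
os.environ["VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ"] = "1"  # alltoallv_seq token dispatcher
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"              # dual batch overlap (optional)

from vllm import LLM, SamplingParams

# Placeholder MoE model and engine arguments, for illustration only.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=4,
    enable_expert_parallel=True,  # the alltoallv path targets expert-parallel MoE
)
outputs = llm.generate(
    ["The president of the United States is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```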
1 parent: f08283a · commit: 63944db
18 files changed: +2541 additions, −183 deletions

requirements.txt (1 addition, 0 deletions)

@@ -28,3 +28,4 @@ torch-npu==2.5.1.post1.dev20250619
 
 # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
 transformers<4.53.0
+pytest_mock

tests/multicard/test_offline_inference_distributed.py (24 additions, 0 deletions)

@@ -154,6 +154,30 @@ def test_models_distributed_DeepSeekV3_dbo():
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
+@patch.dict(os.environ, {
+    "VLLM_ASCEND_ENABLE_DBO": "1",
+    "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"
+})
+def test_models_distributed_DeepSeekV3_alltoallv_dbo():
+    example_prompts = ["The president of the United States is"] * 10
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=30, temperature=0.0)
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        model_arch = 'DeepseekV3ForCausalLM'
+        registed_models = ModelRegistry.models
+        assert registed_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registed_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 def test_models_distributed_DeepSeek_W8A8():
     example_prompts = [
         "Hello, my name is",
New test file (139 additions, 0 deletions; file path not shown in this extract), covering vllm_ascend.distributed.tensor_parallel:

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import importlib
import unittest
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_ascend.distributed.tensor_parallel import (
    _gather_along_first_dim, _gather_along_last_dim,
    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
    all_to_all_hp2sp, all_to_all_sp2hp)


@pytest.fixture
def test_tensor():
    return torch.randn(8, 16)


@pytest.fixture
def test_tensor_last_dim():
    return torch.randn(8, 16, 32)


@pytest.fixture
def mock_group():
    return MagicMock()


@pytest.fixture(autouse=True)
def mock_dist():
    with patch("torch.distributed") as mock:
        mock.get_world_size.return_value = 4
        mock.get_rank.return_value = 0
        yield mock


class TestDistributedCommunication(unittest.TestCase):

    @pytest.mark.parametrize("world_size", [1, 4])
    def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
                                    world_size):
        """test _gather_along_first_dim"""
        mock_dist.get_world_size.return_value = world_size

        result = _gather_along_first_dim(test_tensor, mock_group)

        if world_size == 1:
            self.assertEqual(result.shape, (8, 16))
        else:
            self.assertEqual(result.shape, (32, 16))  # 8*4=32

    def test_gather_along_first_dim_unequal_split(self, test_tensor,
                                                  mock_group):
        """test unequal split"""
        output_split_sizes = [5, 10, 15, 2]
        result = _gather_along_first_dim(test_tensor, mock_group,
                                         output_split_sizes)
        self.assertEqual(result.shape, (32, 16))  # 5+10+15+2=32

    @pytest.mark.parametrize("world_size", [1, 4])
    def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
                                   mock_dist, world_size):
        """test _gather_along_last_dim"""
        mock_dist.get_world_size.return_value = world_size

        result = _gather_along_last_dim(test_tensor_last_dim, mock_group)

        self.assertEqual(result.shape, (8, 16, 32 * world_size))

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, mock_group, input_shape,
                                            expected_shape):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
        self.assertEqual(result.shape, expected_shape)

    def test_reduce_scatter_along_last_dim(self, mock_group):
        input_tensor = torch.randn(8, 16, 32)
        result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
        self.assertEqual(result.shape, (8, 16, 8))

    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
         (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
         (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, mock_group, func, input_shape,
                               expected_shape):
        """test wrapper funcs"""
        mod = importlib.import_module(
            'vllm_ascend.distributed.tensor_parallel')
        globals = mod.__dict__
        test_func = globals[func]
        input_tensor = torch.randn(*input_shape)
        result = test_func(input_tensor, mock_group)
        self.assertEqual(result.shape, expected_shape)

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
        ])
    def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_sp2hp(input_tensor, mock_group)
        self.assertEqual(result.shape, output_shape)

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
        ])
    def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_hp2sp(input_tensor, mock_group)
        self.assertEqual(result.shape, output_shape)

tests/ut/test_token_dispatcher.py (new file, 69 additions, 0 deletions)

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import unittest

import pytest
from pytest_mock import MockerFixture

from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.utils import adapt_patch  # noqa E402

import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa

adapt_patch(True)


class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase):

    @pytest.fixture
    def config(self):
        config = MoEDispatcherConfig()
        config.set_num_local_experts(2)
        config.set_num_moe_experts(4)
        config.set_moe_pad_expert_input_to_capacity(False)
        config.set_moe_expert_capacity_factor(None)
        config.set_moe_router_topk(2)
        config.set_moe_grouped_gemm(False)
        config.set_group_topk(0)
        config.set_num_groups(1)
        config.set_is_fused(False)
        return config.build()

    def mock_ep_group(self, mocker):
        mock_group = mocker.MagicMock()
        mock_group.rank_in_group = 0
        mock_group.world_size = 2
        mock_group.device_group = "mock_group"
        return mock_group

    @pytest.fixture
    def dispatcher(self, config, mocker: MockerFixture):
        mocker.patch(
            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
            return_value=self.mock_ep_group(mocker))
        return MoEAlltoAllSeqOverLapDispatcher(config)

    def test_initialization(self, dispatcher, config):
        self.assertEqual(dispatcher.num_local_experts,
                         config.num_local_experts)
        self.assertEqual(dispatcher.num_experts, config.num_moe_experts)
        self.assertEqual(dispatcher.local_expert_indices, [0, 1])
        self.assertEqual(dispatcher.ep_rank, 0)
        self.assertEqual(dispatcher.ep_size, 2)
        self.assertIsNotNone(dispatcher.overlap_stream)

vllm_ascend/ascend_forward_context.py (45 additions, 31 deletions)

@@ -10,20 +10,26 @@
 from vllm.platforms import current_platform
 
 import vllm_ascend.envs as envs
+import vllm_ascend.envs as envs_ascend
 
 
 class FusedMoEState(Enum):
     AllGather = 0
     All2All = 1
     MC2 = 2
     MC2_PREFILL = 3
+    All2AllSeq = 4
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
     enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
     if ep_size == 1:
         return FusedMoEState.AllGather
+    elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
+        # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
+        return (FusedMoEState.All2AllSeq if
+                (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
     elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
         return FusedMoEState.MC2_PREFILL
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.

@@ -35,27 +41,30 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool):
 
 @contextmanager
 def set_ascend_forward_context(
-        attn_metadata: Any,
-        vllm_config: VllmConfig,
-        virtual_engine: int = 0,
-        num_tokens: Optional[int] = None,
-        num_tokens_across_dp: Optional[torch.Tensor] = None,
-        with_prefill: bool = True,
-        in_profile_run: bool = False,
-        num_actual_tokens: Optional[int] = None):
+    attn_metadata: Any,
+    vllm_config: VllmConfig,
+    virtual_engine: int = 0,
+    num_tokens: Optional[int] = None,
+    num_tokens_across_dp: Optional[torch.Tensor] = None,
+    with_prefill: bool = True,
+    in_profile_run: bool = False,
+    num_actual_tokens: Optional[int] = None,
+):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
     """
-    with set_forward_context(attn_metadata,
-                             vllm_config,
-                             virtual_engine=virtual_engine,
-                             num_tokens=num_tokens,
-                             num_tokens_across_dp=num_tokens_across_dp):
+    with set_forward_context(
+            attn_metadata,
+            vllm_config,
+            virtual_engine=virtual_engine,
+            num_tokens=num_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
+    ):
         forward_context = get_forward_context()
         forward_context.with_prefill = with_prefill
-        ep_size = torch.distributed.get_world_size(
-        ) if vllm_config.parallel_config.enable_expert_parallel else 1
+        ep_size = (torch.distributed.get_world_size() if
+                   vllm_config.parallel_config.enable_expert_parallel else 1)
 
         fused_moe_state = get_fused_moe_state(ep_size, with_prefill)
 
@@ -68,20 +77,21 @@ def set_ascend_forward_context(
         forward_context.capturing = False
 
         if num_tokens is None and attn_metadata is not None:
-            if hasattr(attn_metadata, 'num_actual_tokens'):
+            if hasattr(attn_metadata, "num_actual_tokens"):
                 # for v1 engine
                 num_tokens = attn_metadata.num_actual_tokens
             else:
                 # for v0 engine
-                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
+                num_tokens = (attn_metadata.num_prefill_tokens +
+                              attn_metadata.num_decode_tokens)
 
         if num_actual_tokens is None:
            num_actual_tokens = num_tokens
 
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
-            )
+            max_tokens_across_dp = (
+                forward_context.dp_metadata.max_tokens_across_dp_cpu.item())
         else:
             max_tokens_across_dp = num_tokens
 
@@ -91,29 +101,33 @@ def set_ascend_forward_context(
             tp_world_size = get_tp_group().world_size
             world_size = torch.distributed.get_world_size()
             # NOTE: token num which need to pad to when mc2
-            forward_context.padded_num_tokens = math.ceil(
-                max_tokens_across_dp / tp_world_size) * tp_world_size
+            forward_context.padded_num_tokens = (
+                math.ceil(max_tokens_across_dp / tp_world_size) *
+                tp_world_size)
             # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs.
-            forward_context.global_bs = math.ceil(
-                max_tokens_across_dp / tp_world_size) * world_size
+            forward_context.global_bs = (
+                math.ceil(max_tokens_across_dp / tp_world_size) * world_size)
 
             if fused_moe_state == FusedMoEState.MC2_PREFILL:
                 chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
                 forward_context.max_num_chunks = math.ceil(
                     math.ceil(max_tokens_across_dp / tp_world_size) /
                     chunk_size)
 
-                forward_context.global_bs = math.ceil(
+                forward_context.global_bs = (math.ceil(
                     math.ceil(max_tokens_across_dp / tp_world_size) /
-                    forward_context.max_num_chunks) * world_size
+                    forward_context.max_num_chunks) * world_size)
 
                 min_num_tokens = forward_context.max_num_chunks * tp_world_size
-                forward_context.padded_num_tokens = math.ceil(
-                    max_tokens_across_dp / min_num_tokens) * min_num_tokens
-
-            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
-                                   dtype=torch.bool,
-                                   device=current_platform.device_type)
+                forward_context.padded_num_tokens = (
+                    math.ceil(max_tokens_across_dp / min_num_tokens) *
+                    min_num_tokens)
+
+            mc2_mask = torch.zeros(
+                forward_context.padded_num_tokens,
+                dtype=torch.bool,
+                device=current_platform.device_type,
+            )
             mc2_mask[:num_actual_tokens] = True
             forward_context.mc2_mask = mc2_mask

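Reading the two hunks of `ascend_forward_context.py` together, the dispatch-mode selection after this change can be summarized with the sketch below. The final `All2All`/`MC2` fallthrough is not visible in the hunks above and is reconstructed here as an assumption from the surrounding context lines, so treat it as illustrative rather than authoritative.

```python
from enum import Enum

import vllm_ascend.envs as envs
import vllm_ascend.envs as envs_ascend  # alias introduced by this PR


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    MC2_PREFILL = 3
    All2AllSeq = 4  # new state for the alltoallv_seq dispatcher


def get_fused_moe_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
    enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
    if ep_size == 1:
        return FusedMoEState.AllGather
    elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
        # MC2 dispatch/combine performs better than alltoall_seq at decode,
        # so All2AllSeq is only chosen for small EP sizes or during prefill.
        return (FusedMoEState.All2AllSeq if
                (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
    elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
        return FusedMoEState.MC2_PREFILL
    # Assumed fallthrough (not shown in the hunks): MC2 needs ep_size >= 16,
    # and all2all cannot be used inside the torchair graph.
    elif ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    else:
        return FusedMoEState.MC2
```

In short, with `VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ=1` the new dispatcher takes over for prefill and for small expert-parallel sizes, while large-EP decode continues to prefer MC2.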