-
Notifications
You must be signed in to change notification settings - Fork 257
[0.9.1][Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo #1547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ganyi1996ppo
merged 68 commits into
vllm-project:v0.9.1-dev
from
weijinqian0:v0.9.1-dev
Jul 14, 2025
Merged
Changes from all commits
Commits
Show all changes
68 commits
Select commit
Hold shift + click to select a range
7ff288e
[Feature]Moe alltoallv communication optimization for unquantized RL …
6d7b5b4
[Feature]Moe alltoallv communication optimization for unquantized RL …
6a8e1a9
[Feature]Moe alltoallv communication optimization for unquantized RL …
4805c5a
[Feature]Moe alltoallv communication optimization for unquantized RL …
d68ce07
[Feature]Moe alltoallv communication optimization for unquantized RL …
0aff693
[Feature]Moe alltoallv communication optimization for unquantized RL …
f6ab19e
[Feature]Moe alltoallv communication optimization for unquantized RL …
a94c094
[Feature]Moe alltoallv communication optimization for unquantized RL …
91570d8
[Feature]Moe alltoallv communication optimization for unquantized RL …
e7c0d2d
[Feature]Moe alltoallv communication optimization for unquantized RL …
47439e8
[Feature]Moe alltoallv communication optimization for unquantized RL …
cf3f1c8
[Feature]Moe alltoallv communication optimization for unquantized RL …
a4126f3
[Feature]Moe alltoallv communication optimization for unquantized RL …
807aaf0
[Feature]Moe alltoallv communication optimization for unquantized RL …
6f6efc1
[Feature]Moe alltoallv communication optimization for unquantized RL …
305a0eb
handle conflict
5411ed6
add st:qwen3
3f88769
add st for moe token dispatcher
854c149
fix bug
harygo22 d0bd006
add st for moe token dispatcher
49e9771
add moe_block: AscendSparseMoeBlock
a9bccf8
add moe_block: AscendSparseMoeBlock
e31a7df
add moe_block: AscendSparseMoeBlock
0a22312
[0.9.1][Perf] Optimize the number of rope-related index selections in…
whx-sjtu ee1dd49
[BUGFIX] FIX mtp accuraccy when temperture is not 0 (#1632)
JC-ut0 eef1093
add mc2 mask (#1642)
weiguihua2 eb54e22
[cherry-pick] static EPLB fix bug, add unit test to v0.9.1-dev (#1667)
songshanhu07 b02ad40
revert
harygo22 66807e0
fix bug
harygo22 d24758e
fix a bug
harygo22 d76c4fb
fix a bug
harygo22 f883902
ut test
harygo22 d5656f4
liscens & fix dsk dbo.
harygo22 df52070
handle conflict
adf3f74
handle code clean
5956ef0
handle code clean
af85566
handle code clean
d4ad734
handle code clean
847d52d
Merge branch 'v0.9.1-dev' into v0.9.1-dev
harygo22 3b7269a
fix comment
harygo22 deb4319
handle code clean
8b369df
handle code conflict
a8b3e15
fix init
harygo22 d290b7d
remove files & move sparsemoeblock to ops
harygo22 2c102d3
remove test
harygo22 565fa2d
clean code
harygo22 f980ad0
fix clean code
harygo22 969ee25
typo
harygo22 a70be9a
renaming cuda sync point
harygo22 1f09708
Merge branch 'v0.9.1-dev' of https://github.com/weijinqian0/vllm-asce…
402f889
handle code clean
141407d
handle code clean
b1d7305
handle code clean
62cebe1
handle clean code
80b1d0d
handle clean code
e87df11
handle clean code
267db60
handle clean code
b0572c8
handle clean code
e4f1050
handle clean code
eaed83d
handle clean code
b97baf4
handle clean code
d232d49
handle clean code
1e435e6
handle clean code
8effdd0
handle clean code
a8136b7
handle code clean
d29deae
handle code conflict
94b7b5b
handle code clean
c2f670d
fix header
harygo22 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# | ||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
# Copyright 2023 The vLLM team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# This file is a part of the vllm-ascend project. | ||
|
||
import importlib | ||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
import torch | ||
|
||
from vllm_ascend.distributed.tensor_parallel import ( | ||
_gather_along_first_dim, _gather_along_last_dim, | ||
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, | ||
all_to_all_hp2sp, all_to_all_sp2hp) | ||
|
||
|
||
@pytest.fixture | ||
def test_tensor(): | ||
return torch.randn(8, 16) | ||
|
||
|
||
@pytest.fixture | ||
def test_tensor_last_dim(): | ||
return torch.randn(8, 16, 32) | ||
|
||
|
||
@pytest.fixture | ||
def mock_group(): | ||
return MagicMock() | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_dist(): | ||
with patch("torch.distributed") as mock: | ||
mock.get_world_size.return_value = 4 | ||
mock.get_rank.return_value = 0 | ||
yield mock | ||
|
||
|
||
class TestDistributedCommunication(unittest.TestCase): | ||
|
||
@pytest.mark.parametrize("world_size", [1, 4]) | ||
def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, | ||
world_size): | ||
"""test _gather_along_first_dim""" | ||
mock_dist.get_world_size.return_value = world_size | ||
|
||
result = _gather_along_first_dim(test_tensor, mock_group) | ||
|
||
if world_size == 1: | ||
self.assertEqual(result.shape, (8, 16)) | ||
else: | ||
self.assertEqual(result.shape, (32, 16)) # 8*4=32 | ||
|
||
def test_gather_along_first_dim_unequal_split(self, test_tensor, | ||
mock_group): | ||
"""test unequal split""" | ||
output_split_sizes = [5, 10, 15, 2] | ||
result = _gather_along_first_dim(test_tensor, mock_group, | ||
output_split_sizes) | ||
self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32 | ||
|
||
@pytest.mark.parametrize("world_size", [1, 4]) | ||
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, | ||
mock_dist, world_size): | ||
"""test _gather_along_last_dim""" | ||
mock_dist.get_world_size.return_value = world_size | ||
|
||
result = _gather_along_last_dim(test_tensor_last_dim, mock_group) | ||
|
||
self.assertEqual(result.shape, (8, 16, 32 * world_size)) | ||
|
||
@pytest.mark.parametrize("input_shape,expected_shape", [ | ||
((32, 16), (8, 16)), | ||
((40, 10), (10, 10)), | ||
]) | ||
def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, | ||
expected_shape): | ||
input_tensor = torch.randn(*input_shape) | ||
result = _reduce_scatter_along_first_dim(input_tensor, mock_group) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
def test_reduce_scatter_along_last_dim(self, mock_group): | ||
input_tensor = torch.randn(8, 16, 32) | ||
result = _reduce_scatter_along_last_dim(input_tensor, mock_group) | ||
self.assertEqual(result.shape, (8, 16, 8)) | ||
|
||
@pytest.mark.parametrize("func,input_shape,expected_shape", [ | ||
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), | ||
(8, 16, 128)), | ||
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), | ||
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), | ||
(8, 16, 8)), | ||
("gather_from_sequence_parallel_region", (8, 16), (32, 16)), | ||
]) | ||
def test_wrapper_functions(self, mock_group, func, input_shape, | ||
expected_shape): | ||
"""test wrapper funcs""" | ||
mod = importlib.import_module( | ||
'vllm_ascend.distributed.tensor_parallel') | ||
globals = mod.__dict__ | ||
test_func = globals[func] | ||
input_tensor = torch.randn(*input_shape) | ||
result = test_func(input_tensor, mock_group) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
@pytest.mark.parametrize( | ||
"input_shape,output_shape", | ||
[ | ||
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] | ||
]) | ||
def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape): | ||
input_tensor = torch.randn(*input_shape) | ||
result = all_to_all_sp2hp(input_tensor, mock_group) | ||
self.assertEqual(result.shape, output_shape) | ||
|
||
@pytest.mark.parametrize( | ||
"input_shape,output_shape", | ||
[ | ||
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] | ||
]) | ||
def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape): | ||
input_tensor = torch.randn(*input_shape) | ||
result = all_to_all_hp2sp(input_tensor, mock_group) | ||
self.assertEqual(result.shape, output_shape) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# | ||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
# Copyright 2023 The vLLM team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# This file is a part of the vllm-ascend project. | ||
|
||
import unittest | ||
|
||
import pytest | ||
from pytest_mock import MockerFixture | ||
|
||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( | ||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) | ||
from vllm_ascend.utils import adapt_patch # noqa E402 | ||
|
||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa | ||
|
||
adapt_patch(True) | ||
|
||
|
||
class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase): | ||
|
||
@pytest.fixture | ||
def config(self): | ||
config = MoEDispatcherConfig() | ||
config.set_num_local_experts(2) | ||
config.set_num_moe_experts(4) | ||
config.set_moe_pad_expert_input_to_capacity(False) | ||
config.set_moe_expert_capacity_factor(None) | ||
config.set_moe_router_topk(2) | ||
config.set_moe_grouped_gemm(False) | ||
config.set_group_topk(0) | ||
config.set_num_groups(1) | ||
config.set_is_fused(False) | ||
return config.build() | ||
|
||
def mock_ep_group(self, mocker): | ||
mock_group = mocker.MagicMock() | ||
mock_group.rank_in_group = 0 | ||
mock_group.world_size = 2 | ||
mock_group.device_group = "mock_group" | ||
return mock_group | ||
|
||
@pytest.fixture | ||
def dispatcher(self, config, mocker: MockerFixture): | ||
mocker.patch( | ||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", | ||
return_value=self.mock_ep_group(mocker)) | ||
return MoEAlltoAllSeqOverLapDispatcher(config) | ||
|
||
def test_initialization(self, dispatcher, config): | ||
self.assertEqual(dispatcher.num_local_experts, | ||
config.num_local_experts) | ||
self.assertEqual(dispatcher.num_experts, config.num_moe_experts) | ||
self.assertEqual(dispatcher.local_expert_indices, [0, 1]) | ||
self.assertEqual(dispatcher.ep_rank, 0) | ||
self.assertEqual(dispatcher.ep_size, 2) | ||
self.assertIsNotNone(dispatcher.overlap_stream) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.