-
Notifications
You must be signed in to change notification settings - Fork 279
[0.9.1][Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo #1547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 46 commits
7ff288e
6d7b5b4
6a8e1a9
4805c5a
d68ce07
0aff693
f6ab19e
a94c094
91570d8
e7c0d2d
47439e8
cf3f1c8
a4126f3
807aaf0
6f6efc1
305a0eb
5411ed6
3f88769
854c149
d0bd006
49e9771
a9bccf8
e31a7df
0a22312
ee1dd49
eef1093
eb54e22
b02ad40
66807e0
d24758e
d76c4fb
f883902
d5656f4
df52070
adf3f74
5956ef0
af85566
d4ad734
847d52d
3b7269a
deb4319
8b369df
a8b3e15
d290b7d
2c102d3
565fa2d
f980ad0
969ee25
a70be9a
1f09708
402f889
141407d
b1d7305
62cebe1
80b1d0d
e87df11
267db60
b0572c8
e4f1050
eaed83d
b97baf4
d232d49
1e435e6
8effdd0
a8136b7
d29deae
94b7b5b
c2f670d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
# Copyright 2023 The vLLM team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# This file is a part of the vllm-ascend project. | ||
# | ||
""" | ||
Compare the outputs of vLLM with and without aclgraph. | ||
Run `pytest tests/multicard/test_data_parallel.py`. | ||
""" | ||
|
||
import os | ||
import subprocess | ||
import sys | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"] | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("max_tokens", [32]) | ||
@patch.dict( | ||
os.environ, { | ||
"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3", | ||
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", | ||
"VLLM_ASCEND_ENABLE_DBO": "1" | ||
}) | ||
def test_qwen3_moe_inference(model, max_tokens): | ||
script = "examples/offline_data_parallel.py" | ||
|
||
env = os.environ.copy() | ||
|
||
cmd = [ | ||
sys.executable, | ||
script, | ||
"--model", | ||
model, | ||
"--dp-size", | ||
"2", | ||
"--tp-size", | ||
"2", | ||
"--node-size", | ||
"1", | ||
"--node-rank", | ||
"0", | ||
"--trust-remote-code", | ||
"--enforce-eager", | ||
] | ||
|
||
print(f"Running subprocess: {' '.join(cmd)}") | ||
proc = subprocess.run(cmd, | ||
env=env, | ||
stdout=subprocess.PIPE, | ||
stderr=subprocess.STDOUT, | ||
timeout=600) | ||
output = proc.stdout.decode() | ||
|
||
print(output) | ||
|
||
assert "DP rank 0 needs to process" in output | ||
assert "DP rank 1 needs to process" in output | ||
assert "Generated text:" in output | ||
assert proc.returncode == 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# | ||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
# Copyright 2023 The vLLM team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# This file is a part of the vllm-ascend project. | ||
|
||
import importlib | ||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
import torch | ||
|
||
from vllm_ascend.distributed.tensor_parallel import ( | ||
_gather_along_first_dim, _gather_along_last_dim, | ||
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, | ||
all_to_all_hp2sp, all_to_all_sp2hp) | ||
|
||
|
||
@pytest.fixture | ||
def test_tensor(): | ||
return torch.randn(8, 16) | ||
|
||
|
||
@pytest.fixture | ||
def test_tensor_last_dim(): | ||
return torch.randn(8, 16, 32) | ||
|
||
|
||
@pytest.fixture | ||
def mock_group(): | ||
return MagicMock() | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_dist(): | ||
with patch("torch.distributed") as mock: | ||
mock.get_world_size.return_value = 4 | ||
mock.get_rank.return_value = 0 | ||
yield mock | ||
|
||
|
||
class TestDistributedCommunication(unittest.TestCase): | ||
|
||
@pytest.mark.parametrize("world_size", [1, 4]) | ||
def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, | ||
world_size): | ||
"""test _gather_along_first_dim""" | ||
mock_dist.get_world_size.return_value = world_size | ||
|
||
result = _gather_along_first_dim(test_tensor, mock_group) | ||
|
||
if world_size == 1: | ||
self.assertEqual(result.shape, (8, 16)) | ||
else: | ||
self.assertEqual(result.shape, (32, 16)) # 8*4=32 | ||
|
||
def test_gather_along_first_dim_unequal_split(self, test_tensor, | ||
mock_group): | ||
"""test unequal split""" | ||
output_split_sizes = [5, 10, 15, 2] | ||
result = _gather_along_first_dim(test_tensor, mock_group, | ||
output_split_sizes) | ||
self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32 | ||
|
||
@pytest.mark.parametrize("world_size", [1, 4]) | ||
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, | ||
mock_dist, world_size): | ||
"""test _gather_along_last_dim""" | ||
mock_dist.get_world_size.return_value = world_size | ||
|
||
result = _gather_along_last_dim(test_tensor_last_dim, mock_group) | ||
|
||
self.assertEqual(result.shape, (8, 16, 32 * world_size)) | ||
|
||
@pytest.mark.parametrize("input_shape,expected_shape", [ | ||
((32, 16), (8, 16)), | ||
((40, 10), (10, 10)), | ||
]) | ||
def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, | ||
expected_shape): | ||
input_tensor = torch.randn(*input_shape) | ||
result = _reduce_scatter_along_first_dim(input_tensor, mock_group) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
def test_reduce_scatter_along_last_dim(self, mock_group): | ||
input_tensor = torch.randn(8, 16, 32) | ||
result = _reduce_scatter_along_last_dim(input_tensor, mock_group) | ||
self.assertEqual(result.shape, (8, 16, 8)) | ||
|
||
@pytest.mark.parametrize("func,input_shape,expected_shape", [ | ||
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), | ||
(8, 16, 128)), | ||
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), | ||
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), | ||
(8, 16, 8)), | ||
("gather_from_sequence_parallel_region", (8, 16), (32, 16)), | ||
]) | ||
def test_wrapper_functions(self, mock_group, func, input_shape, | ||
expected_shape): | ||
"""test wrapper funcs""" | ||
mod = importlib.import_module( | ||
'vllm_ascend.distributed.tensor_parallel') | ||
globals = mod.__dict__ | ||
test_func = globals[func] | ||
input_tensor = torch.randn(*input_shape) | ||
result = test_func(input_tensor, mock_group) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
@pytest.mark.parametrize( | ||
"input_shape,output_shape", | ||
[ | ||
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] | ||
]) | ||
def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape): | ||
input_tensor = torch.randn(*input_shape) | ||
result = all_to_all_sp2hp(input_tensor, mock_group) | ||
self.assertEqual(result.shape, output_shape) | ||
|
||
@pytest.mark.parametrize( | ||
"input_shape,output_shape", | ||
[ | ||
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] | ||
]) | ||
def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape): | ||
input_tensor = torch.randn(*input_shape) | ||
result = all_to_all_hp2sp(input_tensor, mock_group) | ||
self.assertEqual(result.shape, output_shape) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# | ||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
# Copyright 2023 The vLLM team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# This file is a part of the vllm-ascend project. | ||
|
||
import pytest | ||
import torch | ||
import unittest | ||
from pytest_mock import MockerFixture | ||
|
||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( | ||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) | ||
from vllm_ascend.utils import adapt_patch # noqa E402 | ||
|
||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa | ||
|
||
|
||
|
||
adapt_patch(True) | ||
|
||
|
||
class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase): | ||
|
||
@pytest.fixture | ||
def config(self): | ||
config = MoEDispatcherConfig() | ||
config.set_num_local_experts(2) | ||
config.set_num_moe_experts(4) | ||
config.set_moe_pad_expert_input_to_capacity(False) | ||
config.set_moe_expert_capacity_factor(None) | ||
config.set_moe_router_topk(2) | ||
config.set_moe_grouped_gemm(False) | ||
config.set_group_topk(0) | ||
config.set_num_groups(1) | ||
config.set_is_fused(False) | ||
return config.build() | ||
|
||
def mock_ep_group(self, mocker): | ||
mock_group = mocker.MagicMock() | ||
mock_group.rank_in_group = 0 | ||
mock_group.world_size = 2 | ||
mock_group.device_group = "mock_group" | ||
return mock_group | ||
|
||
@pytest.fixture | ||
def dispatcher(self, config, mocker: MockerFixture): | ||
mocker.patch( | ||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", | ||
return_value=self.mock_ep_group(mocker)) | ||
return MoEAlltoAllSeqOverLapDispatcher(config) | ||
|
||
def test_initialization(self, dispatcher, config): | ||
self.assertEqual(dispatcher.num_local_experts, config.num_local_experts) | ||
self.assertEqual(dispatcher.num_experts, config.num_moe_experts) | ||
self.assertEqual(dispatcher.local_expert_indices, [0, 1]) | ||
self.assertEqual(dispatcher.ep_rank, 0) | ||
self.assertEqual(dispatcher.ep_size, 2) | ||
self.assertIsNotNone(dispatcher.overlap_stream) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,19 +11,26 @@ | |
|
||
import vllm_ascend.envs as envs | ||
|
||
import vllm_ascend.envs as envs_ascend | ||
|
||
|
||
class FusedMoEState(Enum): | ||
AllGather = 0 | ||
All2All = 1 | ||
MC2 = 2 | ||
MC2_PREFILL = 3 | ||
All2AllSeq = 4 | ||
|
||
|
||
# TODO(zzzzwwjj): add soc_version to choose branch | ||
def get_fused_moe_state(ep_size: int, with_prefill: bool): | ||
enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2 | ||
if ep_size == 1: | ||
return FusedMoEState.AllGather | ||
elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: | ||
# MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. | ||
return FusedMoEState.All2AllSeq if ( | ||
ep_size < 16 or with_prefill) else FusedMoEState.MC2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why there is this restriction There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MC2 Dispatch/Combine is still faster than alltoall_seq in decoding stage. so when ep_size >= 16, use MC2 for better performance. |
||
elif ep_size >= 16 and with_prefill and enable_chunk_mc2: | ||
return FusedMoEState.MC2_PREFILL | ||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. | ||
|
Uh oh!
There was an error while loading. Please reload this page.