
[0.9.1][Feature] MoE alltoallv communication optimization for unquantized RL training scenario & alltoallv support for DBO #1547


Merged: 68 commits merged on Jul 14, 2025

Commits (changes from all 68 commits shown below)
7ff288e
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jun 30, 2025
6d7b5b4
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
6a8e1a9
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
4805c5a
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
d68ce07
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
0aff693
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
f6ab19e
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
a94c094
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
91570d8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
e7c0d2d
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
47439e8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
cf3f1c8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
a4126f3
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
807aaf0
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
6f6efc1
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 2, 2025
305a0eb
handle conflict
Jul 8, 2025
5411ed6
add st:qwen3
Jul 8, 2025
3f88769
add st for moe token dispatcher
Jul 8, 2025
854c149
fix bug
harygo22 Jul 8, 2025
d0bd006
add st for moe token dispatcher
Jul 8, 2025
49e9771
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
a9bccf8
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
e31a7df
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
0a22312
[0.9.1][Perf] Optimize the number of rope-related index selections in…
whx-sjtu Jul 8, 2025
ee1dd49
[BUGFIX] FIX mtp accuraccy when temperture is not 0 (#1632)
JC-ut0 Jul 8, 2025
eef1093
add mc2 mask (#1642)
weiguihua2 Jul 8, 2025
eb54e22
[cherry-pick] static EPLB fix bug, add unit test to v0.9.1-dev (#1667)
songshanhu07 Jul 8, 2025
b02ad40
revert
harygo22 Jul 8, 2025
66807e0
fix bug
harygo22 Jul 8, 2025
d24758e
fix a bug
harygo22 Jul 8, 2025
d76c4fb
fix a bug
harygo22 Jul 8, 2025
f883902
ut test
harygo22 Jul 9, 2025
d5656f4
liscens & fix dsk dbo.
harygo22 Jul 9, 2025
df52070
handle conflict
Jul 9, 2025
adf3f74
handle code clean
Jul 9, 2025
5956ef0
handle code clean
Jul 9, 2025
af85566
handle code clean
Jul 9, 2025
d4ad734
handle code clean
Jul 9, 2025
847d52d
Merge branch 'v0.9.1-dev' into v0.9.1-dev
harygo22 Jul 10, 2025
3b7269a
fix comment
harygo22 Jul 10, 2025
deb4319
handle code clean
Jul 9, 2025
8b369df
handle code conflict
Jul 10, 2025
a8b3e15
fix init
harygo22 Jul 10, 2025
d290b7d
remove files & move sparsemoeblock to ops
harygo22 Jul 10, 2025
2c102d3
remove test
harygo22 Jul 11, 2025
565fa2d
clean code
harygo22 Jul 11, 2025
f980ad0
fix clean code
harygo22 Jul 11, 2025
969ee25
typo
harygo22 Jul 11, 2025
a70be9a
renaming cuda sync point
harygo22 Jul 11, 2025
1f09708
Merge branch 'v0.9.1-dev' of https://github.com/weijinqian0/vllm-asce…
Jul 11, 2025
402f889
handle code clean
Jul 11, 2025
141407d
handle code clean
Jul 11, 2025
b1d7305
handle code clean
Jul 11, 2025
62cebe1
handle clean code
Jul 11, 2025
80b1d0d
handle clean code
Jul 11, 2025
e87df11
handle clean code
Jul 11, 2025
267db60
handle clean code
Jul 11, 2025
b0572c8
handle clean code
Jul 11, 2025
e4f1050
handle clean code
Jul 11, 2025
eaed83d
handle clean code
Jul 11, 2025
b97baf4
handle clean code
Jul 11, 2025
d232d49
handle clean code
Jul 11, 2025
1e435e6
handle clean code
Jul 11, 2025
8effdd0
handle clean code
Jul 11, 2025
a8136b7
handle code clean
Jul 12, 2025
d29deae
handle code conflict
Jul 12, 2025
94b7b5b
handle code clean
Jul 12, 2025
c2f670d
fix header
harygo22 Jul 14, 2025
1 change: 1 addition & 0 deletions requirements.txt
@@ -28,3 +28,4 @@ torch-npu==2.5.1.post1.dev20250619

# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
transformers<4.53.0
pytest_mock
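
pytest_mock is added because the new unit tests below rely on its `mocker` fixture for patching. A minimal sketch of that fixture; the patched target (`os.getenv`) and the env var name are chosen only for illustration:

```python
# Minimal pytest-mock sketch: `mocker` patches a callable for one test and
# restores it automatically on teardown. The target and env var name here
# are illustrative only.
import os

from pytest_mock import MockerFixture


def test_mocker_patch_sketch(mocker: MockerFixture):
    mocker.patch("os.getenv", return_value="1")
    assert os.getenv("VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ") == "1"
```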
24 changes: 24 additions & 0 deletions tests/multicard/test_offline_inference_distributed.py
@@ -154,6 +154,30 @@ def test_models_distributed_DeepSeekV3_dbo():
vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.skip(reason="Due to OOM, waiting for PR #1311 to merge")
@patch.dict(os.environ, {
"VLLM_ASCEND_ENABLE_DBO": "1",
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"
})
def test_models_distributed_DeepSeekV3_alltoallv_dbo():
example_prompts = ["The president of the United States is"] * 10
dtype = "half"
sampling_params = SamplingParams(max_tokens=30, temperature=0.0)
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
model_arch = 'DeepseekV3ForCausalLM'
        registered_models = ModelRegistry.models
        assert registered_models[
            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
        assert registered_models[
            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
vllm_model.generate(example_prompts, sampling_params)


def test_models_distributed_DeepSeek_W8A8():
example_prompts = [
"Hello, my name is",
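
The new test turns on the alltoallv path and dual batch overlap purely through environment variables (`VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ` and `VLLM_ASCEND_ENABLE_DBO`) before the model is built. A hedged sketch of the same switch outside the test harness; the model, parallel size and backend mirror the test and are not a tuning recommendation:

```python
# Sketch: enable the alltoall_seq MoE path and DBO before vLLM is initialized,
# mirroring the env vars patched in the test above.
import os

os.environ["VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ"] = "1"
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

from vllm import LLM, SamplingParams  # noqa: E402

if __name__ == "__main__":
    llm = LLM(
        model="vllm-ascend/DeepSeek-V3-Pruning",  # pruned model used by the test suite
        dtype="half",
        tensor_parallel_size=4,
        distributed_executor_backend="mp",
    )
    outputs = llm.generate(
        ["The president of the United States is"] * 10,
        SamplingParams(max_tokens=30, temperature=0.0),
    )
    for out in outputs:
        print(out.outputs[0].text)
```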
139 changes: 139 additions & 0 deletions tests/ut/test_distributed_tensor_parallel.py
@@ -0,0 +1,139 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import importlib
import unittest
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_ascend.distributed.tensor_parallel import (
_gather_along_first_dim, _gather_along_last_dim,
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
all_to_all_hp2sp, all_to_all_sp2hp)


@pytest.fixture
def test_tensor():
return torch.randn(8, 16)


@pytest.fixture
def test_tensor_last_dim():
return torch.randn(8, 16, 32)


@pytest.fixture
def mock_group():
return MagicMock()


@pytest.fixture(autouse=True)
def mock_dist():
with patch("torch.distributed") as mock:
mock.get_world_size.return_value = 4
mock.get_rank.return_value = 0
yield mock


class TestDistributedCommunication(unittest.TestCase):

@pytest.mark.parametrize("world_size", [1, 4])
def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
world_size):
"""test _gather_along_first_dim"""
mock_dist.get_world_size.return_value = world_size

result = _gather_along_first_dim(test_tensor, mock_group)

if world_size == 1:
self.assertEqual(result.shape, (8, 16))
else:
self.assertEqual(result.shape, (32, 16)) # 8*4=32

def test_gather_along_first_dim_unequal_split(self, test_tensor,
mock_group):
"""test unequal split"""
output_split_sizes = [5, 10, 15, 2]
result = _gather_along_first_dim(test_tensor, mock_group,
output_split_sizes)
self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32

@pytest.mark.parametrize("world_size", [1, 4])
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
mock_dist, world_size):
"""test _gather_along_last_dim"""
mock_dist.get_world_size.return_value = world_size

result = _gather_along_last_dim(test_tensor_last_dim, mock_group)

self.assertEqual(result.shape, (8, 16, 32 * world_size))

@pytest.mark.parametrize("input_shape,expected_shape", [
((32, 16), (8, 16)),
((40, 10), (10, 10)),
])
def test_reduce_scatter_along_first_dim(self, mock_group, input_shape,
expected_shape):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
self.assertEqual(result.shape, expected_shape)

def test_reduce_scatter_along_last_dim(self, mock_group):
input_tensor = torch.randn(8, 16, 32)
result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
self.assertEqual(result.shape, (8, 16, 8))

@pytest.mark.parametrize("func,input_shape,expected_shape", [
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
(8, 16, 128)),
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
(8, 16, 8)),
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
])
def test_wrapper_functions(self, mock_group, func, input_shape,
expected_shape):
"""test wrapper funcs"""
mod = importlib.import_module(
'vllm_ascend.distributed.tensor_parallel')
globals = mod.__dict__
test_func = globals[func]
input_tensor = torch.randn(*input_shape)
result = test_func(input_tensor, mock_group)
self.assertEqual(result.shape, expected_shape)

@pytest.mark.parametrize(
"input_shape,output_shape",
[
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
])
def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
input_tensor = torch.randn(*input_shape)
result = all_to_all_sp2hp(input_tensor, mock_group)
self.assertEqual(result.shape, output_shape)

@pytest.mark.parametrize(
"input_shape,output_shape",
[
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
])
def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
input_tensor = torch.randn(*input_shape)
result = all_to_all_hp2sp(input_tensor, mock_group)
self.assertEqual(result.shape, output_shape)
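
The two all_to_all cases encode the layout convention spelled out in their comments: `all_to_all_sp2hp` maps [num_tokens/TP, H] to [num_tokens, H/TP] and `all_to_all_hp2sp` inverts it. A shape-only sketch of that bookkeeping (no communication, just the arithmetic behind the asserted shapes, assuming TP = 4 as in the mocked world size):

```python
# Shape-only sketch of the sp2hp / hp2sp layout change asserted above.
def sp2hp_shape(tokens_per_rank: int, hidden: int, tp: int) -> tuple:
    # [num_tokens / TP, H] -> [num_tokens, H / TP]
    return (tokens_per_rank * tp, hidden // tp)


def hp2sp_shape(num_tokens: int, hidden_per_rank: int, tp: int) -> tuple:
    # [num_tokens, H / TP] -> [num_tokens / TP, H]
    return (num_tokens // tp, hidden_per_rank * tp)


assert sp2hp_shape(8, 16, tp=4) == (32, 4)   # matches test_all_to_all_sp2hp
assert hp2sp_shape(32, 4, tp=4) == (8, 16)   # matches test_all_to_all_hp2sp
```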
69 changes: 69 additions & 0 deletions tests/ut/test_token_dispatcher.py
@@ -0,0 +1,69 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import unittest

import pytest
from pytest_mock import MockerFixture

from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.utils import adapt_patch # noqa E402

import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa

adapt_patch(True)


class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase):

@pytest.fixture
def config(self):
config = MoEDispatcherConfig()
config.set_num_local_experts(2)
config.set_num_moe_experts(4)
config.set_moe_pad_expert_input_to_capacity(False)
config.set_moe_expert_capacity_factor(None)
config.set_moe_router_topk(2)
config.set_moe_grouped_gemm(False)
config.set_group_topk(0)
config.set_num_groups(1)
config.set_is_fused(False)
return config.build()

def mock_ep_group(self, mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
return mock_group

@pytest.fixture
def dispatcher(self, config, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
return_value=self.mock_ep_group(mocker))
return MoEAlltoAllSeqOverLapDispatcher(config)

def test_initialization(self, dispatcher, config):
self.assertEqual(dispatcher.num_local_experts,
config.num_local_experts)
self.assertEqual(dispatcher.num_experts, config.num_moe_experts)
self.assertEqual(dispatcher.local_expert_indices, [0, 1])
self.assertEqual(dispatcher.ep_rank, 0)
self.assertEqual(dispatcher.ep_size, 2)
self.assertIsNotNone(dispatcher.overlap_stream)
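
The initialization assertions are consistent with a contiguous per-rank expert layout: 4 experts split over ep_size = 2 gives 2 local experts, and rank 0 owns indices [0, 1]. A small sketch of that assumption (the dispatcher's real index derivation may differ; this only reproduces the arithmetic the test checks):

```python
# Sketch of the contiguous local-expert assignment assumed by the assertions above.
def local_expert_indices(ep_rank: int, num_local_experts: int) -> list:
    start = ep_rank * num_local_experts
    return list(range(start, start + num_local_experts))


assert local_expert_indices(ep_rank=0, num_local_experts=2) == [0, 1]  # rank 0, as in the test
assert local_expert_indices(ep_rank=1, num_local_experts=2) == [2, 3]  # rank 1 would own the rest
```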
76 changes: 45 additions & 31 deletions vllm_ascend/ascend_forward_context.py
@@ -10,20 +10,26 @@
from vllm.platforms import current_platform

import vllm_ascend.envs as envs
import vllm_ascend.envs as envs_ascend


class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
MC2_PREFILL = 3
All2AllSeq = 4


# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool):
enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
if ep_size == 1:
return FusedMoEState.AllGather
elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
# MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
return (FusedMoEState.All2AllSeq if
(ep_size < 16 or with_prefill) else FusedMoEState.MC2)
elif ep_size >= 16 and with_prefill and enable_chunk_mc2:
return FusedMoEState.MC2_PREFILL
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
@@ -35,27 +41,30 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool):

@contextmanager
def set_ascend_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
with_prefill: bool = True,
in_profile_run: bool = False,
num_actual_tokens: Optional[int] = None):
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
with_prefill: bool = True,
in_profile_run: bool = False,
num_actual_tokens: Optional[int] = None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
"""
with set_forward_context(attn_metadata,
vllm_config,
virtual_engine=virtual_engine,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
with set_forward_context(
attn_metadata,
vllm_config,
virtual_engine=virtual_engine,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
forward_context = get_forward_context()
forward_context.with_prefill = with_prefill
ep_size = torch.distributed.get_world_size(
) if vllm_config.parallel_config.enable_expert_parallel else 1
ep_size = (torch.distributed.get_world_size() if
vllm_config.parallel_config.enable_expert_parallel else 1)

fused_moe_state = get_fused_moe_state(ep_size, with_prefill)

@@ -68,20 +77,21 @@ def set_ascend_forward_context(
forward_context.capturing = False

if num_tokens is None and attn_metadata is not None:
if hasattr(attn_metadata, 'num_actual_tokens'):
if hasattr(attn_metadata, "num_actual_tokens"):
# for v1 engine
num_tokens = attn_metadata.num_actual_tokens
else:
# for v0 engine
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
num_tokens = (attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens)

if num_actual_tokens is None:
num_actual_tokens = num_tokens

dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
)
max_tokens_across_dp = (
forward_context.dp_metadata.max_tokens_across_dp_cpu.item())
else:
max_tokens_across_dp = num_tokens

@@ -91,29 +101,33 @@
tp_world_size = get_tp_group().world_size
world_size = torch.distributed.get_world_size()
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
forward_context.padded_num_tokens = (
math.ceil(max_tokens_across_dp / tp_world_size) *
tp_world_size)
# NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs.
forward_context.global_bs = math.ceil(
max_tokens_across_dp / tp_world_size) * world_size
forward_context.global_bs = (
math.ceil(max_tokens_across_dp / tp_world_size) * world_size)

if fused_moe_state == FusedMoEState.MC2_PREFILL:
chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
forward_context.max_num_chunks = math.ceil(
math.ceil(max_tokens_across_dp / tp_world_size) /
chunk_size)

forward_context.global_bs = math.ceil(
forward_context.global_bs = (math.ceil(
math.ceil(max_tokens_across_dp / tp_world_size) /
forward_context.max_num_chunks) * world_size
forward_context.max_num_chunks) * world_size)

min_num_tokens = forward_context.max_num_chunks * tp_world_size
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / min_num_tokens) * min_num_tokens

mc2_mask = torch.zeros(forward_context.padded_num_tokens,
dtype=torch.bool,
device=current_platform.device_type)
forward_context.padded_num_tokens = (
math.ceil(max_tokens_across_dp / min_num_tokens) *
min_num_tokens)

mc2_mask = torch.zeros(
forward_context.padded_num_tokens,
dtype=torch.bool,
device=current_platform.device_type,
)
mc2_mask[:num_actual_tokens] = True
forward_context.mc2_mask = mc2_mask
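
Read together, the new `All2AllSeq` branch in `get_fused_moe_state` means: with `VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ` set, alltoall_seq is used whenever ep_size < 16 or the step contains prefill, while decode at ep_size >= 16 still falls back to MC2 (which performs better there, per the inline comment). A standalone sketch of just the branches visible in this hunk, with the env flags passed as plain booleans so it runs outside vLLM:

```python
# Standalone sketch of the selection logic shown in this hunk; the real
# get_fused_moe_state reads the flags from vllm_ascend.envs instead of
# taking them as arguments, and has further branches not shown here.
from enum import Enum


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    MC2_PREFILL = 3
    All2AllSeq = 4


def select_state(ep_size: int, with_prefill: bool,
                 enable_all2all_seq: bool, enable_chunk_mc2: bool) -> FusedMoEState:
    if ep_size == 1:
        return FusedMoEState.AllGather
    if enable_all2all_seq:
        # MC2 dispatch/combine performs better than alltoall_seq when decoding at large EP.
        return (FusedMoEState.All2AllSeq
                if (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
    if ep_size >= 16 and with_prefill and enable_chunk_mc2:
        return FusedMoEState.MC2_PREFILL
    raise NotImplementedError("remaining branches are outside the shown hunk")


assert select_state(8, True, enable_all2all_seq=True, enable_chunk_mc2=False) is FusedMoEState.All2AllSeq
assert select_state(16, False, enable_all2all_seq=True, enable_chunk_mc2=False) is FusedMoEState.MC2
```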
