[0.9.1][Feature] Moe alltoallv communication optimization for unquantized RL training scene & alltoallv support dpo #1547

Merged
merged 68 commits on Jul 14, 2025
Changes from 40 commits (68 commits total)

Commits
7ff288e
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jun 30, 2025
6d7b5b4
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
6a8e1a9
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
4805c5a
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
d68ce07
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
0aff693
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
f6ab19e
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
a94c094
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
91570d8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
e7c0d2d
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
47439e8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
cf3f1c8
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
a4126f3
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
807aaf0
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 1, 2025
6f6efc1
[Feature]Moe alltoallv communication optimization for unquantized RL …
Jul 2, 2025
305a0eb
handle conflict
Jul 8, 2025
5411ed6
add st:qwen3
Jul 8, 2025
3f88769
add st for moe token dispatcher
Jul 8, 2025
854c149
fix bug
harygo22 Jul 8, 2025
d0bd006
add st for moe token dispatcher
Jul 8, 2025
49e9771
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
a9bccf8
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
e31a7df
add moe_block: AscendSparseMoeBlock
Jul 8, 2025
0a22312
[0.9.1][Perf] Optimize the number of rope-related index selections in…
whx-sjtu Jul 8, 2025
ee1dd49
[BUGFIX] FIX mtp accuraccy when temperture is not 0 (#1632)
JC-ut0 Jul 8, 2025
eef1093
add mc2 mask (#1642)
weiguihua2 Jul 8, 2025
eb54e22
[cherry-pick] static EPLB fix bug, add unit test to v0.9.1-dev (#1667)
songshanhu07 Jul 8, 2025
b02ad40
revert
harygo22 Jul 8, 2025
66807e0
fix bug
harygo22 Jul 8, 2025
d24758e
fix a bug
harygo22 Jul 8, 2025
d76c4fb
fix a bug
harygo22 Jul 8, 2025
f883902
ut test
harygo22 Jul 9, 2025
d5656f4
liscens & fix dsk dbo.
harygo22 Jul 9, 2025
df52070
handle conflict
Jul 9, 2025
adf3f74
handle code clean
Jul 9, 2025
5956ef0
handle code clean
Jul 9, 2025
af85566
handle code clean
Jul 9, 2025
d4ad734
handle code clean
Jul 9, 2025
847d52d
Merge branch 'v0.9.1-dev' into v0.9.1-dev
harygo22 Jul 10, 2025
3b7269a
fix comment
harygo22 Jul 10, 2025
deb4319
handle code clean
Jul 9, 2025
8b369df
handle code conflict
Jul 10, 2025
a8b3e15
fix init
harygo22 Jul 10, 2025
d290b7d
remove files & move sparsemoeblock to ops
harygo22 Jul 10, 2025
2c102d3
remove test
harygo22 Jul 11, 2025
565fa2d
clean code
harygo22 Jul 11, 2025
f980ad0
fix clean code
harygo22 Jul 11, 2025
969ee25
typo
harygo22 Jul 11, 2025
a70be9a
renaming cuda sync point
harygo22 Jul 11, 2025
1f09708
Merge branch 'v0.9.1-dev' of https://github.com/weijinqian0/vllm-asce…
Jul 11, 2025
402f889
handle code clean
Jul 11, 2025
141407d
handle code clean
Jul 11, 2025
b1d7305
handle code clean
Jul 11, 2025
62cebe1
handle clean code
Jul 11, 2025
80b1d0d
handle clean code
Jul 11, 2025
e87df11
handle clean code
Jul 11, 2025
267db60
handle clean code
Jul 11, 2025
b0572c8
handle clean code
Jul 11, 2025
e4f1050
handle clean code
Jul 11, 2025
eaed83d
handle clean code
Jul 11, 2025
b97baf4
handle clean code
Jul 11, 2025
d232d49
handle clean code
Jul 11, 2025
1e435e6
handle clean code
Jul 11, 2025
8effdd0
handle clean code
Jul 11, 2025
a8136b7
handle code clean
Jul 12, 2025
d29deae
handle code conflict
Jul 12, 2025
94b7b5b
handle code clean
Jul 12, 2025
c2f670d
fix header
harygo22 Jul 14, 2025
75 changes: 75 additions & 0 deletions tests/multicard/test_qwen3_moe.py
@@ -0,0 +1,75 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
End-to-end test of Qwen3 MoE inference with the all-to-all-seq MoE dispatcher and DBO enabled.
Run `pytest tests/multicard/test_qwen3_moe.py`.
"""

import os
import subprocess
import sys
from unittest.mock import patch

import pytest

MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@patch.dict(
os.environ, {
"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3",
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1",
"VLLM_ASCEND_ENABLE_DBO": "1"
})
def test_qwen3_moe_inference(model, max_tokens):
script = "examples/offline_data_parallel.py"

env = os.environ.copy()

cmd = [
sys.executable,
script,
"--model",
model,
"--dp-size",
"2",
"--tp-size",
"2",
"--node-size",
"1",
"--node-rank",
"0",
"--trust-remote-code",
"--enforce-eager",
]

print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600)
output = proc.stdout.decode()

print(output)

assert "DP rank 0 needs to process" in output
assert "DP rank 1 needs to process" in output
assert "Generated text:" in output
assert proc.returncode == 0
127 changes: 127 additions & 0 deletions tests/ut/test_distributed_tensor_parallel.py
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator: Wrong header
Reply: fixed

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import importlib
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_ascend.distributed.tensor_parallel import (
_gather_along_first_dim, _gather_along_last_dim,
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
all_to_all_hp2sp, all_to_all_sp2hp)


@pytest.fixture
def test_tensor():
return torch.randn(8, 16)


@pytest.fixture
def test_tensor_last_dim():
return torch.randn(8, 16, 32)


@pytest.fixture
def mock_group():
return MagicMock()


@pytest.fixture(autouse=True)
def mock_dist():
with patch("torch.distributed") as mock:
mock.get_world_size.return_value = 4
mock.get_rank.return_value = 0
yield mock


class TestDistributedCommunication:
Collaborator: Better to inherit from a test base class, e.g. `class TestBase(unittest.TestCase):`, to make sure the tests will not be impacted by future patches.
Reply: ut/base.py is removed in v0.9.1-dev; we inherit from unittest.TestCase instead.
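For reference, a minimal sketch of the suggested unittest.TestCase style (the class and test names here are hypothetical; it reuses the same torch.distributed patching approach as the pytest fixtures above):

import unittest
from unittest.mock import MagicMock, patch

import torch

from vllm_ascend.distributed.tensor_parallel import _gather_along_first_dim


class TestGatherAlongFirstDim(unittest.TestCase):
    """Hypothetical unittest-style variant of the pytest tests in this file."""

    def setUp(self):
        self.test_tensor = torch.randn(8, 16)
        self.mock_group = MagicMock()

    def test_single_rank_returns_input(self):
        # With a mocked world size of 1, the gather is expected to return the input unchanged.
        with patch("torch.distributed") as mock_dist:
            mock_dist.get_world_size.return_value = 1
            result = _gather_along_first_dim(self.test_tensor, self.mock_group)
        self.assertTrue(torch.equal(result, self.test_tensor))


if __name__ == "__main__":
    unittest.main()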


@pytest.mark.parametrize("world_size", [1, 4])
def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist,
world_size):
"""test _gather_along_first_dim"""
mock_dist.get_world_size.return_value = world_size

result = _gather_along_first_dim(test_tensor, mock_group)

if world_size == 1:
assert torch.equal(result, test_tensor)
Collaborator: Please use self.assertXXX; same for the other changes.
Reply: done
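As a small, self-contained illustration of the requested self.assertXXX style (names are hypothetical):

import unittest

import torch


class TestAssertStyle(unittest.TestCase):
    """Hypothetical example of the self.assertXXX style requested above."""

    def test_gathered_shape(self):
        gathered = torch.randn(32, 16)  # stand-in for a gathered result
        # Instead of `assert gathered.shape == (32, 16)`:
        self.assertEqual(gathered.shape, (32, 16))


if __name__ == "__main__":
    unittest.main()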

else:
assert result.shape == (32, 16) # 8*4=32

def test_gather_along_first_dim_unequal_split(self, test_tensor,
mock_group):
"""test unequal split"""
output_split_sizes = [5, 10, 15, 2]
result = _gather_along_first_dim(test_tensor, mock_group,
output_split_sizes)
assert result.shape == (32, 16) # 5+10+15+2=32

@pytest.mark.parametrize("world_size", [1, 4])
def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group,
mock_dist, world_size):
"""test _gather_along_last_dim"""
mock_dist.get_world_size.return_value = world_size

result = _gather_along_last_dim(test_tensor_last_dim, mock_group)

if world_size == 1:
assert torch.equal(result, test_tensor_last_dim)
else:
assert result.shape == (8, 16, 32 * world_size) # 32*4=128

@pytest.mark.parametrize("input_shape,expected_shape", [
((32, 16), (8, 16)),
((40, 10), (10, 10)),
])
def test_reduce_scatter_along_first_dim(self, mock_group, input_shape,
expected_shape):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
assert result.shape == expected_shape

def test_reduce_scatter_along_last_dim(self, mock_group):
input_tensor = torch.randn(8, 16, 32)
result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
assert result.shape == (8, 16, 8) # 32/4=8

@pytest.mark.parametrize("func,input_shape,expected_shape", [
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
(8, 16, 128)),
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
(8, 16, 8)),
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
])
def test_wrapper_functions(self, mock_group, func, input_shape,
expected_shape):
"""test wrapper funcs"""
mod = importlib.import_module(
'vllm_ascend.distributed.tensor_parallel')
globals = mod.__dict__
test_func = globals[func]
input_tensor = torch.randn(*input_shape)
result = test_func(input_tensor, mock_group)
assert result.shape == expected_shape

@pytest.mark.parametrize(
"input_shape,output_shape",
[
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
])
def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
input_tensor = torch.randn(*input_shape)
result = all_to_all_sp2hp(input_tensor, mock_group)
assert result.shape == output_shape

@pytest.mark.parametrize(
"input_shape,output_shape",
[
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
])
def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
input_tensor = torch.randn(*input_shape)
result = all_to_all_hp2sp(input_tensor, mock_group)
assert result.shape == output_shape
169 changes: 169 additions & 0 deletions tests/ut/test_moe_util.py
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import math

import pytest
import torch

from vllm_ascend.ops.moe_dispatcher.moe_utils import (
get_capacity, group_limited_topk, permute, sort_chunks_by_idxs,
topk_softmax_with_capacity, unpermute)

import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa



class TestMoeUtils:

@pytest.fixture
def setup(self):
self.num_tokens = 16
self.num_experts = 4
self.hidden_size = 8
self.topk = 2
self.capacity_factor = 1.0
self.group_topk = 2
self.num_groups = 2
self.scaling_factor = 1.0

def test_group_limited_topk(self, setup):
# Test group-limited topk routing
scores = torch.randn(self.num_tokens, self.num_experts)
probs, indices = group_limited_topk(scores,
topk=self.topk,
num_tokens=self.num_tokens,
num_experts=self.num_experts,
num_groups=self.num_groups,
group_topk=self.group_topk)

assert probs.shape == (self.num_tokens, self.topk)
assert indices.shape == (self.num_tokens, self.topk)
assert torch.all(indices < self.num_experts)

@pytest.mark.parametrize("score_function", ["softmax"])
def test_topk_softmax_with_capacity(self, setup, score_function):
# Test topk softmax with capacity
logits = torch.randn(self.num_tokens, self.num_experts)

# Test without capacity
probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
logits, topk=self.topk, score_function=score_function)
assert probs.shape == (self.num_tokens, self.num_experts)
assert routing_map.shape == (self.num_tokens, self.num_experts)
assert tokens_per_expert.shape == (self.num_experts, )

# Test with group routing
probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
logits,
topk=self.topk,
num_groups=self.num_groups,
group_topk=self.group_topk,
score_function=score_function)
assert probs.shape == (self.num_tokens, self.num_experts)

def test_get_capacity(self, setup):
# Test capacity calculation
capacity = get_capacity(num_tokens=self.num_tokens,
num_experts=self.num_experts,
capacity_factor=self.capacity_factor)
expected = math.ceil(
(self.num_tokens / self.num_experts) * self.capacity_factor)
assert capacity == expected

# Test with min capacity
min_capacity = 5
capacity = get_capacity(num_tokens=self.num_tokens,
num_experts=self.num_experts,
capacity_factor=self.capacity_factor,
min_capacity=min_capacity)
assert capacity == min_capacity

def test_permute(self, setup):
# Test token permutation
tokens = torch.randn(self.num_tokens, self.hidden_size)
routing_map = torch.randint(
0, 2, (self.num_tokens, self.num_experts)).bool()

# Basic permutation
permuted_tokens, sorted_indices = permute(tokens, routing_map)
assert permuted_tokens.shape[0] == routing_map.sum()
assert sorted_indices.shape[0] == routing_map.sum()

# With drop and pad
capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
num_experts=self.num_experts,
capacity_factor=self.capacity_factor)
num_out_tokens = capacity * self.num_experts
permuted_tokens, sorted_indices = permute(
tokens,
routing_map,
num_out_tokens=num_out_tokens,
drop_and_pad=True)
assert permuted_tokens.shape[0] == num_out_tokens
assert sorted_indices.shape[0] == num_out_tokens

def test_unpermute(self, setup):
# Test token unpermutation
tokens = torch.randn(self.num_tokens, self.hidden_size)
routing_map = torch.randint(
0, 2, (self.num_tokens, self.num_experts)).bool()
probs = torch.rand(self.num_tokens, self.num_experts)

# First permute
permuted_tokens, sorted_indices = permute(tokens, routing_map)

# Then unpermute
restored_tokens = unpermute(permuted_tokens,
sorted_indices,
tokens.shape,
probs=probs,
routing_map=routing_map)
assert restored_tokens.shape == tokens.shape

# With drop and pad
capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
num_experts=self.num_experts,
capacity_factor=self.capacity_factor)
num_out_tokens = capacity * self.num_experts
permuted_tokens, sorted_indices = permute(
tokens,
routing_map,
num_out_tokens=num_out_tokens,
drop_and_pad=True)
restored_tokens = unpermute(permuted_tokens,
sorted_indices,
tokens.shape,
probs=probs,
routing_map=routing_map,
drop_and_pad=True)
assert restored_tokens.shape == tokens.shape

def test_sort_chunks_by_idxs(self, setup):
# Test chunk sorting
input_tensor = torch.randn(10, self.hidden_size)
split_sizes = torch.tensor([3, 2, 5])
sorted_idxs = torch.tensor([2, 0, 1])

output = sort_chunks_by_idxs(input_tensor, split_sizes, sorted_idxs)
assert output.shape == input_tensor.shape

# Verify the order is correct
expected = torch.cat(
[input_tensor[5:], input_tensor[0:3], input_tensor[3:5]])
assert torch.allclose(output, expected)

@pytest.mark.parametrize("score_function", ["softmax"])
def test_score_functions(self, setup, score_function):
# Test different score functions
logits = torch.randn(self.num_tokens, self.num_experts)
expert_bias = torch.randn(self.num_experts)

probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
logits,
topk=self.topk,
score_function=score_function,
expert_bias=expert_bias)
assert probs.shape == (self.num_tokens, self.num_experts)
assert routing_map.shape == (self.num_tokens, self.num_experts)
assert tokens_per_expert.shape == (self.num_experts, )