
Commit af4917a

Merge remote-tracking branch 'wjq/v0.9.1-dev' into wjq_091_dev
2 parents a6838f7 + d556d49

File tree

14 files changed: +490 -297 lines


tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 4 additions & 2 deletions
@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
         enforce_eager=False,
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
         },
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# fused moe ops test will hit the infer_schema error, we need add the patch
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa

import json
import unittest
from typing import List, TypedDict
from unittest import mock

import torch

from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer


class Device(TypedDict):
    device_id: int
    device_expert: List[int]


class Layer(TypedDict):
    layer_id: int
    device_count: int
    device_list: List[Device]


class MockData(TypedDict):
    moe_layer_count: int
    layer_list: List[Layer]


MOCK_DATA: MockData = {
    "moe_layer_count": 1,
    "layer_list": [{
        "layer_id": 0,
        "device_count": 2,
        "device_list": [{
            "device_id": 0,
            "device_expert": [7, 2, 0, 3, 5]
        }, {
            "device_id": 1,
            "device_expert": [6, 1, 4, 7, 2]
        }]
    }]
}


class TestExpertLoadBalancer(unittest.TestCase):

    def setUp(self):
        json_file = "expert_map.json"
        with open(json_file, 'w') as f:
            json.dump(MOCK_DATA, f)

        self.expert_load_balancer = ExpertLoadBalancer(json_file,
                                                       global_expert_num=8)

    def test_init(self):
        self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
                              torch.Tensor)
        self.assertEqual(self.expert_load_balancer.layers_num,
                         MOCK_DATA["moe_layer_count"])
        self.assertEqual(self.expert_load_balancer.ranks_num,
                         MOCK_DATA["layer_list"][0]["device_count"])

    def test_generate_index_dicts(self):
        tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
        result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
        expected_result = [{
            7: 0,
            2: 1,
            0: 2,
            3: 3,
            5: 4
        }, {
            6: 5,
            1: 6,
            4: 7,
            7: 8,
            2: 9
        }]
        self.assertEqual(result, expected_result)

    def test_generate_expert_placement_map(self):
        expert_placement_map = self.expert_load_balancer.generate_expert_placement_map()
        self.assertEqual(expert_placement_map.shape,
                         (self.expert_load_balancer.layers_num,
                          self.expert_load_balancer.ranks_num, 8))
        self.assertTrue(torch.all(expert_placement_map >= -1))

    def test_generate_log2phy_expert_map(self):
        layer_id = 0
        log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(layer_id)
        self.assertEqual(log2phy_map.shape,
                         (self.expert_load_balancer.ranks_num, 8))
        self.assertTrue(torch.all(log2phy_map >= -1))

    @mock.patch("torch_npu.npu._lazy_init")
    @mock.patch("torch.npu.current_device", return_value="cpu")
    def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
        layer_id = 0
        rank_id = 0
        rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
            layer_id, rank_id)
        self.assertEqual(rank_local_expert_num, 5)
        expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
                                       dtype=torch.int32).to(rank_expert_map.device)
        self.assertTrue(rank_expert_map.equal(expected_tensor))

        rank_id = 1
        rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
                                       dtype=torch.int32).to(rank_expert_map.device)
        self.assertTrue(rank_expert_map.equal(expected_tensor))

    def test_get_rank_log2phy_map(self):
        layer_id = 0
        rank_id = 0
        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
                                       dtype=torch.int32).to(log2phy_map.device)
        self.assertTrue(log2phy_map.equal(expected_tensor))

        rank_id = 1
        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
                                       dtype=torch.int32).to(log2phy_map.device)
        self.assertTrue(log2phy_map.equal(expected_tensor))

    def test_get_global_redundant_expert_num(self):
        redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num()
        expected_redundant_expert_num = len(MOCK_DATA["layer_list"][0]["device_list"][0]["device_expert"]) * \
            MOCK_DATA["layer_list"][0]["device_count"] - 8
        self.assertEqual(redundant_expert_num, expected_redundant_expert_num)
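
The expected tensors in test_get_rank_log2phy_map imply a simple logical-to-physical convention: physical expert slots are numbered row-major across ranks, and each rank prefers its own local copy of a replicated expert. The following is a standalone sketch derived only from the mock data above, not the library implementation:

# Standalone illustration of the logical -> physical expert mapping implied
# by the expected tensors in the test above (not ExpertLoadBalancer's code).
device_experts = [[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]]  # per-rank experts from MOCK_DATA
global_expert_num = 8

def log2phy_for_rank(rank_id):
    per_rank = len(device_experts[0])
    mapping = [-1] * global_expert_num
    # First pass: take the first physical copy found, scanning ranks in order.
    for r, experts in enumerate(device_experts):
        for pos, expert in enumerate(experts):
            if mapping[expert] == -1:
                mapping[expert] = r * per_rank + pos
    # Second pass: prefer this rank's local copy when the expert is replicated.
    for pos, expert in enumerate(device_experts[rank_id]):
        mapping[expert] = rank_id * per_rank + pos
    return mapping

# Matches the expected tensors asserted in test_get_rank_log2phy_map.
assert log2phy_for_rank(0) == [2, 6, 1, 3, 7, 4, 5, 0]
assert log2phy_for_rank(1) == [2, 6, 9, 3, 7, 4, 5, 8]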

tests/ut/moe_util.py renamed to tests/ut/test_moe_util.py

Lines changed: 18 additions & 47 deletions
@@ -4,9 +4,9 @@
 import torch
 import pytest
 import math
+import vllm_ascend.patch.worker.patch_common.patch_utils
 
-from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, \
-    group_limited_topk, unpermute, sort_chunks_by_idxs
+from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs
 
 
 class TestMoeUtils:
@@ -22,6 +22,7 @@ def setup(self):
         self.num_groups = 2
         self.scaling_factor = 1.0
 
+
     def test_group_limited_topk(self, setup):
         # Test group-limited topk routing
         scores = torch.randn(self.num_tokens, self.num_experts)
@@ -38,42 +39,33 @@ def test_group_limited_topk(self, setup):
         assert indices.shape == (self.num_tokens, self.topk)
         assert torch.all(indices < self.num_experts)
 
-    def test_topk_softmax_with_capacity(self, setup):
+
+    @pytest.mark.parametrize("score_function", ["softmax"])
+    def test_topk_softmax_with_capacity(self, setup, score_function):
         # Test topk softmax with capacity
         logits = torch.randn(self.num_tokens, self.num_experts)
 
         # Test without capacity
         probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
             logits,
-            topk=self.topk
+            topk=self.topk,
+            score_function=score_function
         )
         assert probs.shape == (self.num_tokens, self.num_experts)
         assert routing_map.shape == (self.num_tokens, self.num_experts)
         assert tokens_per_expert.shape == (self.num_experts,)
 
-        # Test with capacity
-        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
-            logits,
-            topk=self.topk,
-            capacity_factor=self.capacity_factor,
-            pad_to_capacity=True
-        )
-        expert_capacity = get_capacity(
-            num_tokens=self.num_tokens * self.topk,
-            num_experts=self.num_experts,
-            capacity_factor=self.capacity_factor
-        )
-        assert tokens_per_expert.max() <= expert_capacity
-
         # Test with group routing
         probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
             logits,
             topk=self.topk,
             num_groups=self.num_groups,
-            group_topk=self.group_topk
+            group_topk=self.group_topk,
+            score_function=score_function
         )
         assert probs.shape == (self.num_tokens, self.num_experts)
 
+
     def test_get_capacity(self, setup):
         # Test capacity calculation
         capacity = get_capacity(
@@ -94,6 +86,7 @@ def test_get_capacity(self, setup):
         )
         assert capacity == min_capacity
 
+
     def test_permute(self, setup):
         # Test token permutation
         tokens = torch.randn(self.num_tokens, self.hidden_size)
@@ -120,6 +113,7 @@ def test_permute(self, setup):
         assert permuted_tokens.shape[0] == num_out_tokens
         assert sorted_indices.shape[0] == num_out_tokens
 
+
     def test_unpermute(self, setup):
         # Test token unpermutation
         tokens = torch.randn(self.num_tokens, self.hidden_size)
@@ -162,6 +156,7 @@ def test_unpermute(self, setup):
         )
         assert restored_tokens.shape == tokens.shape
 
+
     def test_sort_chunks_by_idxs(self, setup):
         # Test chunk sorting
         input_tensor = torch.randn(10, self.hidden_size)
@@ -173,10 +168,10 @@ def test_sort_chunks_by_idxs(self, setup):
 
         # Verify the order is correct
         expected = torch.cat([input_tensor[5:], input_tensor[0: 3], input_tensor[3: 5]])
-        assert torch.allclose(output, expected) \
-            \
-            @ pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
+        assert torch.allclose(output, expected)
 
+
+    @pytest.mark.parametrize("score_function", ["softmax"])
     def test_score_functions(self, setup, score_function):
         # Test different score functions
         logits = torch.randn(self.num_tokens, self.num_experts)
@@ -190,28 +185,4 @@ def test_score_functions(self, setup, score_function):
         )
         assert probs.shape == (self.num_tokens, self.num_experts)
         assert routing_map.shape == (self.num_tokens, self.num_experts)
-        assert tokens_per_expert.shape == (self.num_experts,)
-
-    def test_edge_cases(self, setup):
-        # Test empty input
-        empty_logits = torch.randn(0, self.num_experts)
-        with pytest.raises(AssertionError):
-            topk_softmax_with_capacity(empty_logits, topk=self.topk)
-
-        # Test invalid score function
-        logits = torch.randn(self.num_tokens, self.num_experts)
-        with pytest.raises(ValueError):
-            topk_softmax_with_capacity(
-                logits,
-                topk=self.topk,
-                score_function="invalid"
-            )
-
-        # Test invalid drop policy
-        with pytest.raises(ValueError):
-            topk_softmax_with_capacity(
-                logits,
-                topk=self.topk,
-                capacity_factor=1.0,
-                drop_policy="invalid"
-            )
+        assert tokens_per_expert.shape == (self.num_experts,)
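
For readers unfamiliar with the routing helper exercised here, the following is a minimal sketch of a call mirroring the parameters used in this test. Argument names and return values are taken only from the calls shown in the diff; tensor sizes are illustrative, and the patch_utils import is assumed to be needed in this environment just as it is in the updated test file:

import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401  # same patch the diff adds
import torch

from vllm_ascend.ops.moe_dispatcher.moe_utils import topk_softmax_with_capacity

# 16 tokens routed over 8 experts, picking the top-2 experts per token (illustrative sizes).
logits = torch.randn(16, 8)
probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
    logits,
    topk=2,
    score_function="softmax",
)
# Per the assertions above: probs (16, 8), routing_map (16, 8), tokens_per_expert (8,)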

tests/ut/test_token_dispatcher.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.

import torch
import pytest
from pytest_mock import MockerFixture
import vllm_ascend.patch.worker.patch_common.patch_utils
from vllm_ascend.utils import adapt_patch  # noqa E402

from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher

adapt_patch(True)


class TestMoEAlltoAllSeqOverLapDispatcher:

    @pytest.fixture
    def config(self):
        config = MoeDispatcherConfig()
        config.set_num_local_experts(2)
        config.set_num_moe_experts(4)
        config.set_moe_pad_expert_input_to_capacity(False)
        config.set_moe_expert_capacity_factor(None)
        config.set_moe_router_topk(2)
        config.set_moe_grouped_gemm(False)
        config.set_group_topk(0)
        config.set_num_groups(1)
        config.set_is_fused(False)
        return config.build()

    def mock_ep_group(self, mocker):
        mock_group = mocker.MagicMock()
        mock_group.rank_in_group = 0
        mock_group.world_size = 2
        mock_group.device_group = "mock_group"
        return mock_group

    @pytest.fixture
    def dispatcher(self, config, mocker: MockerFixture):
        mocker.patch("vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
                     return_value=self.mock_ep_group(mocker))
        return MoEAlltoAllSeqOverLapDispatcher(config)

    def test_initialization(self, dispatcher, config):
        assert dispatcher.num_local_experts == config.num_local_experts
        assert dispatcher.num_experts == config.num_moe_experts
        assert dispatcher.local_expert_indices == [0, 1]
        assert dispatcher.ep_rank == 0
        assert dispatcher.ep_size == 2
        assert dispatcher.overlap_stream is not None

    def test_routing(self, dispatcher):
        probs = torch.randn(4, 4)  # 4 tokens, 4 experts
        scores, routing_map = dispatcher.routing(probs)
        assert scores.shape == (4, 4)  # topk=2
        assert routing_map.shape == (4, 4)
