Commit d0bd006

Author: weijinqian_v1

add st for moe token dispatcher

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>

1 parent: 854c149

File tree

3 files changed, +74 -252 lines changed

tests/ut/moe_util.py renamed to tests/ut/test_moe_util.py

Lines changed: 18 additions & 47 deletions
@@ -4,9 +4,9 @@
 import torch
 import pytest
 import math
+import vllm_ascend.patch.worker.patch_common.patch_utils
 
-from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, \
-    group_limited_topk, unpermute, sort_chunks_by_idxs
+from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs
 
 
 class TestMoeUtils:
@@ -22,6 +22,7 @@ def setup(self):
         self.num_groups = 2
         self.scaling_factor = 1.0
 
+
     def test_group_limited_topk(self, setup):
         # Test group-limited topk routing
         scores = torch.randn(self.num_tokens, self.num_experts)
@@ -38,42 +39,33 @@ def test_group_limited_topk(self, setup):
         assert indices.shape == (self.num_tokens, self.topk)
         assert torch.all(indices < self.num_experts)
 
-    def test_topk_softmax_with_capacity(self, setup):
+
+    @pytest.mark.parametrize("score_function", ["softmax"])
+    def test_topk_softmax_with_capacity(self, setup, score_function):
         # Test topk softmax with capacity
         logits = torch.randn(self.num_tokens, self.num_experts)
 
         # Test without capacity
         probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
             logits,
-            topk=self.topk
+            topk=self.topk,
+            score_function=score_function
         )
         assert probs.shape == (self.num_tokens, self.num_experts)
         assert routing_map.shape == (self.num_tokens, self.num_experts)
         assert tokens_per_expert.shape == (self.num_experts,)
 
-        # Test with capacity
-        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
-            logits,
-            topk=self.topk,
-            capacity_factor=self.capacity_factor,
-            pad_to_capacity=True
-        )
-        expert_capacity = get_capacity(
-            num_tokens=self.num_tokens * self.topk,
-            num_experts=self.num_experts,
-            capacity_factor=self.capacity_factor
-        )
-        assert tokens_per_expert.max() <= expert_capacity
-
         # Test with group routing
         probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
             logits,
             topk=self.topk,
             num_groups=self.num_groups,
-            group_topk=self.group_topk
+            group_topk=self.group_topk,
+            score_function=score_function
         )
         assert probs.shape == (self.num_tokens, self.num_experts)
 
+
     def test_get_capacity(self, setup):
         # Test capacity calculation
         capacity = get_capacity(
@@ -94,6 +86,7 @@ def test_get_capacity(self, setup):
         )
         assert capacity == min_capacity
 
+
     def test_permute(self, setup):
         # Test token permutation
         tokens = torch.randn(self.num_tokens, self.hidden_size)
@@ -120,6 +113,7 @@
         assert permuted_tokens.shape[0] == num_out_tokens
         assert sorted_indices.shape[0] == num_out_tokens
 
+
     def test_unpermute(self, setup):
         # Test token unpermutation
         tokens = torch.randn(self.num_tokens, self.hidden_size)
@@ -162,6 +156,7 @@
         )
         assert restored_tokens.shape == tokens.shape
 
+
     def test_sort_chunks_by_idxs(self, setup):
         # Test chunk sorting
         input_tensor = torch.randn(10, self.hidden_size)
@@ -173,10 +168,10 @@
 
         # Verify the order is correct
         expected = torch.cat([input_tensor[5:], input_tensor[0: 3], input_tensor[3: 5]])
-        assert torch.allclose(output, expected) \
-            \
-    @ pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
+        assert torch.allclose(output, expected)
 
+
+    @pytest.mark.parametrize("score_function", ["softmax"])
     def test_score_functions(self, setup, score_function):
         # Test different score functions
         logits = torch.randn(self.num_tokens, self.num_experts)
@@ -190,28 +185,4 @@ def test_score_functions(self, setup, score_function):
        )
         assert probs.shape == (self.num_tokens, self.num_experts)
         assert routing_map.shape == (self.num_tokens, self.num_experts)
-        assert tokens_per_expert.shape == (self.num_experts,)
-
-    def test_edge_cases(self, setup):
-        # Test empty input
-        empty_logits = torch.randn(0, self.num_experts)
-        with pytest.raises(AssertionError):
-            topk_softmax_with_capacity(empty_logits, topk=self.topk)
-
-        # Test invalid score function
-        logits = torch.randn(self.num_tokens, self.num_experts)
-        with pytest.raises(ValueError):
-            topk_softmax_with_capacity(
-                logits,
-                topk=self.topk,
-                score_function="invalid"
-            )
-
-        # Test invalid drop policy
-        with pytest.raises(ValueError):
-            topk_softmax_with_capacity(
-                logits,
-                topk=self.topk,
-                capacity_factor=1.0,
-                drop_policy="invalid"
-            )
+        assert tokens_per_expert.shape == (self.num_experts,)
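
Note that the rename from moe_util.py to test_moe_util.py also brings the module under pytest's default test_*.py discovery pattern. The deleted block removed the pad-to-capacity path from test_topk_softmax_with_capacity, but test_get_capacity still checks the capacity math directly. For reference, below is a minimal sketch of the conventional capacity rule that test should exercise, assuming the usual ceil(num_tokens / num_experts * capacity_factor) definition with a min_capacity clamp; get_capacity_sketch is an illustrative name, not the vllm_ascend API.

import math

def get_capacity_sketch(num_tokens, num_experts, capacity_factor, min_capacity=None):
    # Even token share per expert, scaled by the capacity factor.
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    # Clamp to the configured floor; the test's min_capacity case
    # asserts `capacity == min_capacity` when the floor dominates.
    if min_capacity is not None and capacity < min_capacity:
        capacity = min_capacity
    return capacity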

tests/ut/test_token_dispatcher.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+
+import torch
+import pytest
+from pytest_mock import MockerFixture
+import vllm_ascend.patch.worker.patch_common.patch_utils
+from vllm_ascend.utils import adapt_patch  # noqa E402
+
+from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher
+
+adapt_patch(True)
+
+class TestMoEAlltoAllSeqOverLapDispatcher:
+
+    @pytest.fixture
+    def config(self):
+        config = MoeDispatcherConfig()
+        config.set_num_local_experts(2)
+        config.set_num_moe_experts(4)
+        config.set_moe_pad_expert_input_to_capacity(False)
+        config.set_moe_expert_capacity_factor(None)
+        config.set_moe_router_topk(2)
+        config.set_moe_grouped_gemm(False)
+        config.set_group_topk(0)
+        config.set_num_groups(1)
+        config.set_is_fused(False)
+        return config.build()
+
+    def mock_ep_group(self, mocker):
+        mock_group = mocker.MagicMock()
+        mock_group.rank_in_group = 0
+        mock_group.world_size = 2
+        mock_group.device_group = "mock_group"
+        return mock_group
+
+    @pytest.fixture
+    def dispatcher(self, config, mocker: MockerFixture):
+        mocker.patch("vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
+                     return_value=self.mock_ep_group(mocker))
+        return MoEAlltoAllSeqOverLapDispatcher(config)
+
+    def test_initialization(self, dispatcher, config):
+        assert dispatcher.num_local_experts == config.num_local_experts
+        assert dispatcher.num_experts == config.num_moe_experts
+        assert dispatcher.local_expert_indices == [0, 1]
+        assert dispatcher.ep_rank == 0
+        assert dispatcher.ep_size == 2
+        assert dispatcher.overlap_stream is not None
+
+    def test_routing(self, dispatcher):
+        probs = torch.randn(4, 4)  # 4 tokens, 4 experts
+        scores, routing_map = dispatcher.routing(probs)
+        assert scores.shape == (4, 4)  # topk=2
+        assert routing_map.shape == (4, 4)
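
The dispatcher fixture patches get_ep_group before constructing the dispatcher, which suggests MoEAlltoAllSeqOverLapDispatcher resolves its expert-parallel group at construction time; that is what lets the unit test build a 2-rank dispatcher with no real NPU communicator. As a hedged sketch, the same setup using only the standard library's unittest.mock (rather than the pytest-mock fixture) could look like this; build_dispatcher_with_fake_ep_group is a hypothetical helper, not part of the test suite.

from unittest import mock

from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    MoEAlltoAllSeqOverLapDispatcher)

def build_dispatcher_with_fake_ep_group(config):
    # Fake 2-rank expert-parallel group, mirroring mock_ep_group above.
    fake_group = mock.MagicMock()
    fake_group.rank_in_group = 0   # this rank's index within the EP group
    fake_group.world_size = 2      # pretend two EP ranks exist
    fake_group.device_group = "mock_group"
    # Patch the lookup only for the duration of construction.
    with mock.patch(
            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
            return_value=fake_group):
        return MoEAlltoAllSeqOverLapDispatcher(config)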
