
Commit 3f88769

Author: weijinqian_v1
add st for moe token dispatcher
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
1 parent 5411ed6 commit 3f88769

File tree

2 files changed: +422 -0 lines changed

tests/ut/moe_util.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import torch
import pytest
import math

from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, \
    group_limited_topk, unpermute, sort_chunks_by_idxs


class TestMoeUtils:

    @pytest.fixture
    def setup(self):
        self.num_tokens = 16
        self.num_experts = 4
        self.hidden_size = 8
        self.topk = 2
        self.capacity_factor = 1.0
        self.group_topk = 2
        self.num_groups = 2
        self.scaling_factor = 1.0

    def test_group_limited_topk(self, setup):
        # Test group-limited topk routing
        scores = torch.randn(self.num_tokens, self.num_experts)
        probs, indices = group_limited_topk(
            scores,
            topk=self.topk,
            num_tokens=self.num_tokens,
            num_experts=self.num_experts,
            num_groups=self.num_groups,
            group_topk=self.group_topk
        )

        assert probs.shape == (self.num_tokens, self.topk)
        assert indices.shape == (self.num_tokens, self.topk)
        assert torch.all(indices < self.num_experts)
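
    # For readers unfamiliar with group-limited routing: a minimal reference of
    # the usual DeepSeek-style scheme the shapes above imply. This is an
    # illustrative sketch, NOT the vllm_ascend implementation under test.
    def _ref_group_limited_topk(self, scores, topk, num_tokens, num_experts,
                                num_groups, group_topk):
        # Score each group by its best expert and keep the top `group_topk` groups.
        group_scores = scores.view(num_tokens, num_groups, -1).max(dim=-1).values
        group_idx = torch.topk(group_scores, k=group_topk, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        # Mask out experts outside the selected groups, then take a plain topk.
        score_mask = group_mask.unsqueeze(-1).expand(
            num_tokens, num_groups, num_experts // num_groups).reshape(num_tokens, -1)
        masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
        return torch.topk(masked_scores, k=topk, dim=-1)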

    def test_topk_softmax_with_capacity(self, setup):
        # Test topk softmax with capacity
        logits = torch.randn(self.num_tokens, self.num_experts)

        # Test without capacity
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk
        )
        assert probs.shape == (self.num_tokens, self.num_experts)
        assert routing_map.shape == (self.num_tokens, self.num_experts)
        assert tokens_per_expert.shape == (self.num_experts,)

        # Test with capacity
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            capacity_factor=self.capacity_factor,
            pad_to_capacity=True
        )
        expert_capacity = get_capacity(
            num_tokens=self.num_tokens * self.topk,
            num_experts=self.num_experts,
            capacity_factor=self.capacity_factor
        )
        assert tokens_per_expert.max() <= expert_capacity

        # Test with group routing
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            num_groups=self.num_groups,
            group_topk=self.group_topk
        )
        assert probs.shape == (self.num_tokens, self.num_experts)
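
    # Note on the capacity path above: with pad_to_capacity=True the dispatcher
    # is expected to drop tokens that overflow an expert and pad each expert to
    # exactly expert_capacity slots, which is why tokens_per_expert.max() is
    # bounded by get_capacity(num_tokens * topk, ...). (Interpretation of the
    # assertion, not taken from the implementation.)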

    def test_get_capacity(self, setup):
        # Test capacity calculation
        capacity = get_capacity(
            num_tokens=self.num_tokens,
            num_experts=self.num_experts,
            capacity_factor=self.capacity_factor
        )
        expected = math.ceil((self.num_tokens / self.num_experts) * self.capacity_factor)
        assert capacity == expected

        # Test with min capacity
        min_capacity = 5
        capacity = get_capacity(
            num_tokens=self.num_tokens,
            num_experts=self.num_experts,
            capacity_factor=self.capacity_factor,
            min_capacity=min_capacity
        )
        assert capacity == min_capacity
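
    # The two assertions above pin down get_capacity's contract; a reference
    # matching them (a sketch, not the vllm_ascend code) would be:
    def _ref_get_capacity(self, num_tokens, num_experts, capacity_factor,
                          min_capacity=None):
        # ceil(tokens per expert * factor), floored at min_capacity when given.
        capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
        if min_capacity is not None and capacity < min_capacity:
            capacity = min_capacity
        return capacity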

    def test_permute(self, setup):
        # Test token permutation
        tokens = torch.randn(self.num_tokens, self.hidden_size)
        routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool()

        # Basic permutation
        permuted_tokens, sorted_indices = permute(tokens, routing_map)
        assert permuted_tokens.shape[0] == routing_map.sum()
        assert sorted_indices.shape[0] == routing_map.sum()

        # With drop and pad
        capacity = get_capacity(
            num_tokens=self.num_tokens * self.topk,
            num_experts=self.num_experts,
            capacity_factor=self.capacity_factor
        )
        num_out_tokens = capacity * self.num_experts
        permuted_tokens, sorted_indices = permute(
            tokens,
            routing_map,
            num_out_tokens=num_out_tokens,
            drop_and_pad=True
        )
        assert permuted_tokens.shape[0] == num_out_tokens
        assert sorted_indices.shape[0] == num_out_tokens
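
    # Contract exercised above (as implied by the assertions): permute
    # replicates each token once per expert that selects it and groups the
    # copies expert-by-expert, so the output has routing_map.sum() rows; with
    # drop_and_pad=True every expert is truncated/padded to `capacity` rows,
    # giving exactly capacity * num_experts rows regardless of the routing.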

    def test_unpermute(self, setup):
        # Test token unpermutation
        tokens = torch.randn(self.num_tokens, self.hidden_size)
        routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool()
        probs = torch.rand(self.num_tokens, self.num_experts)

        # First permute
        permuted_tokens, sorted_indices = permute(tokens, routing_map)

        # Then unpermute
        restored_tokens = unpermute(
            permuted_tokens,
            sorted_indices,
            tokens.shape,
            probs=probs,
            routing_map=routing_map
        )
        assert restored_tokens.shape == tokens.shape

        # With drop and pad
        capacity = get_capacity(
            num_tokens=self.num_tokens * self.topk,
            num_experts=self.num_experts,
            capacity_factor=self.capacity_factor
        )
        num_out_tokens = capacity * self.num_experts
        permuted_tokens, sorted_indices = permute(
            tokens,
            routing_map,
            num_out_tokens=num_out_tokens,
            drop_and_pad=True
        )
        restored_tokens = unpermute(
            permuted_tokens,
            sorted_indices,
            tokens.shape,
            probs=probs,
            routing_map=routing_map,
            drop_and_pad=True
        )
        assert restored_tokens.shape == tokens.shape
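
    # unpermute is the inverse scatter of permute: each expert copy is routed
    # back to its source row and, when probs is given, weighted by the router
    # probability before accumulation, so only the [num_tokens, hidden_size]
    # shape is asserted here. (Reading of the round-trip above, not of the
    # implementation itself.)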

    def test_sort_chunks_by_idxs(self, setup):
        # Test chunk sorting
        input_tensor = torch.randn(10, self.hidden_size)
        split_sizes = torch.tensor([3, 2, 5])
        sorted_idxs = torch.tensor([2, 0, 1])

        output = sort_chunks_by_idxs(input_tensor, split_sizes, sorted_idxs)
        assert output.shape == input_tensor.shape

        # Verify the order is correct
        expected = torch.cat([input_tensor[5:], input_tensor[0:3], input_tensor[3:5]])
        assert torch.allclose(output, expected)
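
    # Reference for the expected ordering checked above (an illustrative
    # sketch; the real sort_chunks_by_idxs may be fused on-device): split
    # along dim 0 by split_sizes, then concatenate in sorted_idxs order.
    def _ref_sort_chunks_by_idxs(self, input_tensor, split_sizes, sorted_idxs):
        chunks = torch.split(input_tensor, split_sizes.tolist())
        return torch.cat([chunks[i] for i in sorted_idxs.tolist()])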

    @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
    def test_score_functions(self, setup, score_function):
        # Test different score functions
        logits = torch.randn(self.num_tokens, self.num_experts)
        expert_bias = torch.randn(self.num_experts)

        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            score_function=score_function,
            expert_bias=expert_bias
        )
        assert probs.shape == (self.num_tokens, self.num_experts)
        assert routing_map.shape == (self.num_tokens, self.num_experts)
        assert tokens_per_expert.shape == (self.num_experts,)

    def test_edge_cases(self, setup):
        # Test empty input
        empty_logits = torch.randn(0, self.num_experts)
        with pytest.raises(AssertionError):
            topk_softmax_with_capacity(empty_logits, topk=self.topk)

        # Test invalid score function
        logits = torch.randn(self.num_tokens, self.num_experts)
        with pytest.raises(ValueError):
            topk_softmax_with_capacity(
                logits,
                topk=self.topk,
                score_function="invalid"
            )

        # Test invalid drop policy
        with pytest.raises(ValueError):
            topk_softmax_with_capacity(
                logits,
                topk=self.topk,
                capacity_factor=1.0,
                drop_policy="invalid"
            )
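
To run the new suite locally, the usual invocation would be something along these lines (assuming a vllm-ascend checkout with pytest installed):

    pytest tests/ut/moe_util.py -v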
