4
4
import torch
5
5
import pytest
6
6
import math
7
+ import vllm_ascend .patch .worker .patch_common .patch_utils
7
8
8
- from vllm_ascend .ops .moe_dispatcher .moe_utils import permute , get_capacity , topk_softmax_with_capacity , \
9
- group_limited_topk , unpermute , sort_chunks_by_idxs
9
+ from vllm_ascend .ops .moe_dispatcher .moe_utils import permute , get_capacity , topk_softmax_with_capacity , group_limited_topk , unpermute , sort_chunks_by_idxs
10
10
11
11
12
12
class TestMoeUtils :
@@ -22,6 +22,7 @@ def setup(self):
22
22
self .num_groups = 2
23
23
self .scaling_factor = 1.0
24
24
25
+
25
26
def test_group_limited_topk (self , setup ):
26
27
# Test group-limited topk routing
27
28
scores = torch .randn (self .num_tokens , self .num_experts )
@@ -38,42 +39,33 @@ def test_group_limited_topk(self, setup):
38
39
assert indices .shape == (self .num_tokens , self .topk )
39
40
assert torch .all (indices < self .num_experts )
40
41
41
    @pytest.mark.parametrize("score_function", ["softmax"])
    def test_topk_softmax_with_capacity(self, setup, score_function):
        """Smoke-test topk_softmax_with_capacity output shapes.

        Runs the router twice — once with plain top-k routing and once with
        group-limited routing — and checks only the shapes of the returned
        tensors: probs and routing_map are (num_tokens, num_experts) and
        tokens_per_expert is (num_experts,). Values are not checked.
        """
        # Random router logits; dimensions come from the `setup` fixture
        # (self.num_tokens / self.num_experts — fixture body not visible here).
        logits = torch.randn(self.num_tokens, self.num_experts)

        # Plain top-k routing: no capacity limit, no expert grouping.
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            score_function=score_function
        )
        assert probs.shape == (self.num_tokens, self.num_experts)
        assert routing_map.shape == (self.num_tokens, self.num_experts)
        assert tokens_per_expert.shape == (self.num_experts,)

        # Group-limited routing — presumably restricts each token's expert
        # choices to group_topk of num_groups expert groups; confirm exact
        # semantics against moe_utils.group_limited_topk.
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            num_groups=self.num_groups,
            group_topk=self.group_topk,
            score_function=score_function
        )
        assert probs.shape == (self.num_tokens, self.num_experts)
77
69
def test_get_capacity (self , setup ):
78
70
# Test capacity calculation
79
71
capacity = get_capacity (
@@ -94,6 +86,7 @@ def test_get_capacity(self, setup):
94
86
)
95
87
assert capacity == min_capacity
96
88
89
+
97
90
def test_permute (self , setup ):
98
91
# Test token permutation
99
92
tokens = torch .randn (self .num_tokens , self .hidden_size )
@@ -120,6 +113,7 @@ def test_permute(self, setup):
120
113
assert permuted_tokens .shape [0 ] == num_out_tokens
121
114
assert sorted_indices .shape [0 ] == num_out_tokens
122
115
116
+
123
117
def test_unpermute (self , setup ):
124
118
# Test token unpermutation
125
119
tokens = torch .randn (self .num_tokens , self .hidden_size )
@@ -162,6 +156,7 @@ def test_unpermute(self, setup):
162
156
)
163
157
assert restored_tokens .shape == tokens .shape
164
158
159
+
165
160
def test_sort_chunks_by_idxs (self , setup ):
166
161
# Test chunk sorting
167
162
input_tensor = torch .randn (10 , self .hidden_size )
@@ -173,10 +168,10 @@ def test_sort_chunks_by_idxs(self, setup):
173
168
174
169
# Verify the order is correct
175
170
expected = torch .cat ([input_tensor [5 :], input_tensor [0 : 3 ], input_tensor [3 : 5 ]])
176
- assert torch .allclose (output , expected ) \
177
- \
178
- @ pytest .mark .parametrize ("score_function" , ["softmax" , "sigmoid" ])
171
+ assert torch .allclose (output , expected )
179
172
173
+
174
+ @pytest .mark .parametrize ("score_function" , ["softmax" ])
180
175
def test_score_functions (self , setup , score_function ):
181
176
# Test different score functions
182
177
logits = torch .randn (self .num_tokens , self .num_experts )
@@ -190,28 +185,4 @@ def test_score_functions(self, setup, score_function):
190
185
)
191
186
assert probs .shape == (self .num_tokens , self .num_experts )
192
187
assert routing_map .shape == (self .num_tokens , self .num_experts )
193
- assert tokens_per_expert .shape == (self .num_experts ,)
194
-
195
- def test_edge_cases (self , setup ):
196
- # Test empty input
197
- empty_logits = torch .randn (0 , self .num_experts )
198
- with pytest .raises (AssertionError ):
199
- topk_softmax_with_capacity (empty_logits , topk = self .topk )
200
-
201
- # Test invalid score function
202
- logits = torch .randn (self .num_tokens , self .num_experts )
203
- with pytest .raises (ValueError ):
204
- topk_softmax_with_capacity (
205
- logits ,
206
- topk = self .topk ,
207
- score_function = "invalid"
208
- )
209
-
210
- # Test invalid drop policy
211
- with pytest .raises (ValueError ):
212
- topk_softmax_with_capacity (
213
- logits ,
214
- topk = self .topk ,
215
- capacity_factor = 1.0 ,
216
- drop_policy = "invalid"
217
- )
188
+ assert tokens_per_expert .shape == (self .num_experts ,)
0 commit comments