 import torch
 import pytest
 import math
-import vllm_ascend.patch.worker.patch_common.patch_utils
+import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa

 from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs

@@ -22,141 +22,118 @@ def setup(self):
         self.num_groups = 2
         self.scaling_factor = 1.0

-
     def test_group_limited_topk(self, setup):
         # Test group-limited topk routing
         scores = torch.randn(self.num_tokens, self.num_experts)
-        probs, indices = group_limited_topk(
-            scores,
-            topk=self.topk,
-            num_tokens=self.num_tokens,
-            num_experts=self.num_experts,
-            num_groups=self.num_groups,
-            group_topk=self.group_topk
-        )
+        probs, indices = group_limited_topk(scores,
+                                            topk=self.topk,
+                                            num_tokens=self.num_tokens,
+                                            num_experts=self.num_experts,
+                                            num_groups=self.num_groups,
+                                            group_topk=self.group_topk)

         assert probs.shape == (self.num_tokens, self.topk)
         assert indices.shape == (self.num_tokens, self.topk)
         assert torch.all(indices < self.num_experts)

-
     @pytest.mark.parametrize("score_function", ["softmax"])
     def test_topk_softmax_with_capacity(self, setup, score_function):
         # Test topk softmax with capacity
         logits = torch.randn(self.num_tokens, self.num_experts)

         # Test without capacity
         probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
-            logits,
-            topk=self.topk,
-            score_function=score_function
-        )
+            logits, topk=self.topk, score_function=score_function)
         assert probs.shape == (self.num_tokens, self.num_experts)
         assert routing_map.shape == (self.num_tokens, self.num_experts)
-        assert tokens_per_expert.shape == (self.num_experts,)
+        assert tokens_per_expert.shape == (self.num_experts, )

         # Test with group routing
         probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
             logits,
             topk=self.topk,
             num_groups=self.num_groups,
             group_topk=self.group_topk,
-            score_function=score_function
-        )
+            score_function=score_function)
         assert probs.shape == (self.num_tokens, self.num_experts)

-
     def test_get_capacity(self, setup):
         # Test capacity calculation
-        capacity = get_capacity(
-            num_tokens=self.num_tokens,
-            num_experts=self.num_experts,
-            capacity_factor=self.capacity_factor
-        )
-        expected = math.ceil((self.num_tokens / self.num_experts) * self.capacity_factor)
+        capacity = get_capacity(num_tokens=self.num_tokens,
+                                num_experts=self.num_experts,
+                                capacity_factor=self.capacity_factor)
+        expected = math.ceil(
+            (self.num_tokens / self.num_experts) * self.capacity_factor)
         assert capacity == expected

         # Test with min capacity
         min_capacity = 5
-        capacity = get_capacity(
-            num_tokens=self.num_tokens,
-            num_experts=self.num_experts,
-            capacity_factor=self.capacity_factor,
-            min_capacity=min_capacity
-        )
+        capacity = get_capacity(num_tokens=self.num_tokens,
+                                num_experts=self.num_experts,
+                                capacity_factor=self.capacity_factor,
+                                min_capacity=min_capacity)
         assert capacity == min_capacity

-
     def test_permute(self, setup):
         # Test token permutation
         tokens = torch.randn(self.num_tokens, self.hidden_size)
-        routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool()
+        routing_map = torch.randint(
+            0, 2, (self.num_tokens, self.num_experts)).bool()

         # Basic permutation
         permuted_tokens, sorted_indices = permute(tokens, routing_map)
         assert permuted_tokens.shape[0] == routing_map.sum()
         assert sorted_indices.shape[0] == routing_map.sum()

         # With drop and pad
-        capacity = get_capacity(
-            num_tokens=self.num_tokens * self.topk,
-            num_experts=self.num_experts,
-            capacity_factor=self.capacity_factor
-        )
+        capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
+                                num_experts=self.num_experts,
+                                capacity_factor=self.capacity_factor)
         num_out_tokens = capacity * self.num_experts
         permuted_tokens, sorted_indices = permute(
             tokens,
             routing_map,
             num_out_tokens=num_out_tokens,
-            drop_and_pad=True
-        )
+            drop_and_pad=True)
         assert permuted_tokens.shape[0] == num_out_tokens
         assert sorted_indices.shape[0] == num_out_tokens

-
     def test_unpermute(self, setup):
         # Test token unpermutation
         tokens = torch.randn(self.num_tokens, self.hidden_size)
-        routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool()
+        routing_map = torch.randint(
+            0, 2, (self.num_tokens, self.num_experts)).bool()
         probs = torch.rand(self.num_tokens, self.num_experts)

         # First permute
         permuted_tokens, sorted_indices = permute(tokens, routing_map)

         # Then unpermute
-        restored_tokens = unpermute(
-            permuted_tokens,
-            sorted_indices,
-            tokens.shape,
-            probs=probs,
-            routing_map=routing_map
-        )
+        restored_tokens = unpermute(permuted_tokens,
+                                    sorted_indices,
+                                    tokens.shape,
+                                    probs=probs,
+                                    routing_map=routing_map)
         assert restored_tokens.shape == tokens.shape

         # With drop and pad
-        capacity = get_capacity(
-            num_tokens=self.num_tokens * self.topk,
-            num_experts=self.num_experts,
-            capacity_factor=self.capacity_factor
-        )
+        capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
+                                num_experts=self.num_experts,
+                                capacity_factor=self.capacity_factor)
         num_out_tokens = capacity * self.num_experts
         permuted_tokens, sorted_indices = permute(
             tokens,
             routing_map,
             num_out_tokens=num_out_tokens,
-            drop_and_pad=True
-        )
-        restored_tokens = unpermute(
-            permuted_tokens,
-            sorted_indices,
-            tokens.shape,
-            probs=probs,
-            routing_map=routing_map,
-            drop_and_pad=True
-        )
+            drop_and_pad=True)
+        restored_tokens = unpermute(permuted_tokens,
+                                    sorted_indices,
+                                    tokens.shape,
+                                    probs=probs,
+                                    routing_map=routing_map,
+                                    drop_and_pad=True)
         assert restored_tokens.shape == tokens.shape

-
     def test_sort_chunks_by_idxs(self, setup):
         # Test chunk sorting
         input_tensor = torch.randn(10, self.hidden_size)
@@ -167,10 +144,10 @@ def test_sort_chunks_by_idxs(self, setup):
         assert output.shape == input_tensor.shape

         # Verify the order is correct
-        expected = torch.cat([input_tensor[5:], input_tensor[0:3], input_tensor[3:5]])
+        expected = torch.cat(
+            [input_tensor[5:], input_tensor[0:3], input_tensor[3:5]])
         assert torch.allclose(output, expected)

-
     @pytest.mark.parametrize("score_function", ["softmax"])
     def test_score_functions(self, setup, score_function):
         # Test different score functions
@@ -181,8 +158,7 @@ def test_score_functions(self, setup, score_function):
             logits,
             topk=self.topk,
             score_function=score_function,
-            expert_bias=expert_bias
-        )
+            expert_bias=expert_bias)
         assert probs.shape == (self.num_tokens, self.num_experts)
         assert routing_map.shape == (self.num_tokens, self.num_experts)
-        assert tokens_per_expert.shape == (self.num_experts,)
+        assert tokens_per_expert.shape == (self.num_experts, )