# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import math

import pytest
import torch

from vllm_ascend.ops.moe_dispatcher.moe_utils import (
    get_capacity, group_limited_topk, permute, sort_chunks_by_idxs,
    topk_softmax_with_capacity, unpermute)
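
# Shape- and behavior-level checks for the MoE token-routing helpers:
# top-k expert selection, expert-capacity computation, token permutation /
# unpermutation, and chunk re-ordering.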


class TestMoeUtils:

    @pytest.fixture
    def setup(self):
        self.num_tokens = 16
        self.num_experts = 4
        self.hidden_size = 8
        self.topk = 2
        self.capacity_factor = 1.0
        self.group_topk = 2
        self.num_groups = 2
        self.scaling_factor = 1.0

    def test_group_limited_topk(self, setup):
        # Test group-limited topk routing
        scores = torch.randn(self.num_tokens, self.num_experts)
        probs, indices = group_limited_topk(scores,
                                            topk=self.topk,
                                            num_tokens=self.num_tokens,
                                            num_experts=self.num_experts,
                                            num_groups=self.num_groups,
                                            group_topk=self.group_topk)

        assert probs.shape == (self.num_tokens, self.topk)
        assert indices.shape == (self.num_tokens, self.topk)
        assert torch.all(indices < self.num_experts)

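    # topk_softmax_with_capacity returns dense (num_tokens, num_experts)
    # probs and routing_map plus a per-expert token count; with
    # pad_to_capacity, no expert should receive more than get_capacity tokens.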
    def test_topk_softmax_with_capacity(self, setup):
        # Test topk softmax with capacity
        logits = torch.randn(self.num_tokens, self.num_experts)

        # Test without capacity
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits, topk=self.topk)
        assert probs.shape == (self.num_tokens, self.num_experts)
        assert routing_map.shape == (self.num_tokens, self.num_experts)
        assert tokens_per_expert.shape == (self.num_experts, )

        # Test with capacity
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            capacity_factor=self.capacity_factor,
            pad_to_capacity=True)
        expert_capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
                                       num_experts=self.num_experts,
                                       capacity_factor=self.capacity_factor)
        assert tokens_per_expert.max() <= expert_capacity

        # Test with group routing
        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            num_groups=self.num_groups,
            group_topk=self.group_topk)
        assert probs.shape == (self.num_tokens, self.num_experts)

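    # Capacity is expected to be ceil(num_tokens / num_experts *
    # capacity_factor), clamped from below by min_capacity when one is given.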
    def test_get_capacity(self, setup):
        # Test capacity calculation
        capacity = get_capacity(num_tokens=self.num_tokens,
                                num_experts=self.num_experts,
                                capacity_factor=self.capacity_factor)
        expected = math.ceil(
            (self.num_tokens / self.num_experts) * self.capacity_factor)
        assert capacity == expected

        # Test with min capacity
        min_capacity = 5
        capacity = get_capacity(num_tokens=self.num_tokens,
                                num_experts=self.num_experts,
                                capacity_factor=self.capacity_factor,
                                min_capacity=min_capacity)
        assert capacity == min_capacity

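    # permute gathers routed token copies into per-expert order; the output
    # row count equals the number of (token, expert) assignments, or exactly
    # num_out_tokens in drop-and-pad mode.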
    def test_permute(self, setup):
        # Test token permutation
        tokens = torch.randn(self.num_tokens, self.hidden_size)
        routing_map = torch.randint(
            0, 2, (self.num_tokens, self.num_experts)).bool()

        # Basic permutation
        permuted_tokens, sorted_indices = permute(tokens, routing_map)
        assert permuted_tokens.shape[0] == routing_map.sum()
        assert sorted_indices.shape[0] == routing_map.sum()

        # With drop and pad
        capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
                                num_experts=self.num_experts,
                                capacity_factor=self.capacity_factor)
        num_out_tokens = capacity * self.num_experts
        permuted_tokens, sorted_indices = permute(
            tokens, routing_map, num_out_tokens=num_out_tokens,
            drop_and_pad=True)
        assert permuted_tokens.shape[0] == num_out_tokens
        assert sorted_indices.shape[0] == num_out_tokens

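    # unpermute is the inverse scatter of permute: the routed copies are
    # combined back into the original (num_tokens, hidden_size) layout; only
    # the restored shape is asserted here.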
    def test_unpermute(self, setup):
        # Test token unpermutation
        tokens = torch.randn(self.num_tokens, self.hidden_size)
        routing_map = torch.randint(
            0, 2, (self.num_tokens, self.num_experts)).bool()
        probs = torch.rand(self.num_tokens, self.num_experts)

        # First permute
        permuted_tokens, sorted_indices = permute(tokens, routing_map)

        # Then unpermute
        restored_tokens = unpermute(permuted_tokens,
                                    sorted_indices,
                                    tokens.shape,
                                    probs=probs,
                                    routing_map=routing_map)
        assert restored_tokens.shape == tokens.shape

        # With drop and pad
        capacity = get_capacity(num_tokens=self.num_tokens * self.topk,
                                num_experts=self.num_experts,
                                capacity_factor=self.capacity_factor)
        num_out_tokens = capacity * self.num_experts
        permuted_tokens, sorted_indices = permute(
            tokens, routing_map, num_out_tokens=num_out_tokens,
            drop_and_pad=True)
        restored_tokens = unpermute(permuted_tokens,
                                    sorted_indices,
                                    tokens.shape,
                                    probs=probs,
                                    routing_map=routing_map,
                                    drop_and_pad=True)
        assert restored_tokens.shape == tokens.shape

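    # sort_chunks_by_idxs splits the input into chunks of the given sizes and
    # concatenates them in the requested order: with sizes [3, 2, 5] and order
    # [2, 0, 1], rows 5:10 come first, then rows 0:3, then rows 3:5.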
    def test_sort_chunks_by_idxs(self, setup):
        # Test chunk sorting
        input_tensor = torch.randn(10, self.hidden_size)
        split_sizes = torch.tensor([3, 2, 5])
        sorted_idxs = torch.tensor([2, 0, 1])

        output = sort_chunks_by_idxs(input_tensor, split_sizes, sorted_idxs)
        assert output.shape == input_tensor.shape

        # Verify the order is correct
        expected = torch.cat(
            [input_tensor[5:], input_tensor[0:3], input_tensor[3:5]])
        assert torch.allclose(output, expected)

    @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
    def test_score_functions(self, setup, score_function):
        # Test different score functions
        logits = torch.randn(self.num_tokens, self.num_experts)
        expert_bias = torch.randn(self.num_experts)

        probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity(
            logits,
            topk=self.topk,
            score_function=score_function,
            expert_bias=expert_bias)
        assert probs.shape == (self.num_tokens, self.num_experts)
        assert routing_map.shape == (self.num_tokens, self.num_experts)
        assert tokens_per_expert.shape == (self.num_experts, )

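    # Invalid inputs should fail fast: empty logits trip an internal
    # assertion, and unrecognized score_function / drop_policy values raise
    # ValueError.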
    def test_edge_cases(self, setup):
        # Test empty input
        empty_logits = torch.randn(0, self.num_experts)
        with pytest.raises(AssertionError):
            topk_softmax_with_capacity(empty_logits, topk=self.topk)

        # Test invalid score function
        logits = torch.randn(self.num_tokens, self.num_experts)
        with pytest.raises(ValueError):
            topk_softmax_with_capacity(logits,
                                       topk=self.topk,
                                       score_function="invalid")

        # Test invalid drop policy
        with pytest.raises(ValueError):
            topk_softmax_with_capacity(logits,
                                       topk=self.topk,
                                       capacity_factor=1.0,
                                       drop_policy="invalid")