@@ -10,7 +10,7 @@
 
 from tests.kernels.moe.utils import (batched_moe,
                                      make_quantized_test_activations,
-                                     make_test_weights, triton_moe)
+                                     make_test_weights, naive_batched_moe)
 from tests.kernels.quant_utils import native_batched_masked_quant_matmul
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
@@ -33,12 +33,10 @@
     (45, 512, 512),
     (45, 1024, 128),
     (45, 1024, 2048),
-    (64, 128, 128),
     (64, 512, 512),
     (64, 1024, 2048),
     (222, 128, 128),
     (222, 128, 2048),
-    (222, 512, 512),
     (222, 1024, 128),
     (222, 1024, 2048),
 ]
@@ -95,11 +93,12 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize("max_tokens_per_expert",
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
-@pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [None])
-@pytest.mark.parametrize("per_act_token_quant", [False])
+@pytest.mark.parametrize("N", [128, 256, 1024])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
                     block_shape: Optional[list[int]],
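For context on the new quantization axes, a brief, hedged sketch (names and shapes below are illustrative, not taken from this change) of how the activation-scale shape differs between the default per-tensor case, `per_act_token_quant=True`, and `block_shape=[128, 128]` for an `(M, K)` activation:

```python
import torch

# Illustrative only: activation-scale shapes for an (M, K) tensor under the
# three quantization modes exercised by the parametrization above.
M, K = 32, 256
block_shape = [128, 128]  # [block_n, block_k]

per_tensor_scale = torch.ones(1)                       # default: one scale overall
per_token_scale = torch.ones(M, 1)                     # per_act_token_quant=True
per_block_scale = torch.ones(M, K // block_shape[1])   # one scale per 128-wide K group

print(per_tensor_scale.shape, per_token_scale.shape, per_block_scale.shape)
```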
@@ -134,7 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
-        per_act_token_quant=per_act_token_quant)
+        per_act_token_quant=per_act_token_quant,
+    )
 
     B, B_q, B_scale, _, _, _ = make_test_weights(
         num_experts,
@@ -143,6 +143,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
+        per_act_token_quant=per_act_token_quant,
     )
 
     out_shape = (num_experts, max_tokens_per_expert, N)
@@ -177,6 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             "BLOCK_SIZE_N": 16,
             "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
         },
+        per_act_token_quant=per_act_token_quant,
         block_shape=block_shape,
     )
 
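The `BLOCK_SIZE_K` expression above keys off `dtype.itemsize`. A quick, hedged illustration (plain PyTorch, nothing vLLM-specific; assumes a build that exposes the fp8 dtypes and `dtype.itemsize`):

```python
import torch

# 1-byte dtypes (fp8) take the larger K block of 32; 2- and 4-byte dtypes take 16.
for dt in (torch.float8_e4m3fn, torch.float16, torch.bfloat16, torch.float32):
    print(dt, dt.itemsize, 16 if dt.itemsize > 1 else 32)
```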
@@ -185,32 +187,31 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         B,
         ref_output,
         num_expert_tokens,
-        None,
-        None,
-        None,
     )
 
     q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                       num_expert_tokens,
                                                       A_scale, B_scale,
-                                                      block_shape)
+                                                      block_shape,
+                                                      per_act_token_quant)
 
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]
 
-    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
 
 
 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("per_act_token_quant", [False])
-@pytest.mark.parametrize("block_shape", [None])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("input_scales", [False])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -220,15 +221,19 @@ def test_fused_moe_batched_experts(
     dtype: torch.dtype,
     per_act_token_quant: bool,
     block_shape: Optional[list[int]],
+    input_scales: bool,
 ):
     current_platform.seed_everything(7)
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
+    if topk > e:
+        pytest.skip("topk > e")
+
     if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
         pytest.skip("Skip quantization test for non-quantized type")
 
-    if per_act_token_quant and block_shape is not None or topk > e:
+    if per_act_token_quant and block_shape is not None:
         pytest.skip("Skip illegal quantization test.")
 
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
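A small, hedged enumeration (mine, not part of the change) of which `(per_act_token_quant, block_shape)` combinations from the expanded grid survive the skips above when fp8 is selected:

```python
from itertools import product

# Per-token scales combined with a block shape are skipped as an illegal mix.
for per_act_token_quant, block_shape in product([False, True], [None, [128, 128]]):
    legal = not (per_act_token_quant and block_shape is not None)
    print(per_act_token_quant, block_shape, "run" if legal else "skip")
```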
@@ -241,55 +246,74 @@ def test_fused_moe_batched_experts(
         act_dtype = dtype
         quant_dtype = None
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
-                                                 n,
-                                                 k,
-                                                 block_shape=block_shape,
-                                                 in_dtype=act_dtype,
-                                                 quant_dtype=quant_dtype)
+    w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
+        e,
+        n,
+        k,
+        block_shape=block_shape,
+        in_dtype=act_dtype,
+        quant_dtype=quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+    )
+
+    if input_scales and quant_dtype is not None:
+        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        batched_output = batched_moe(
+
+        baseline_output = torch_experts(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         )
-        baseline_output = torch_experts(
+
+        batched_output = naive_batched_moe(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
-            block_shape=block_shape)
+            block_shape=block_shape,
+        )
 
-        triton_output = triton_moe(
+        triton_output = batched_moe(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         )
 
-    torch.testing.assert_close(triton_output,
+    torch.testing.assert_close(batched_output,
                                baseline_output,
-                               atol=2e-2,
+                               atol=3e-2,
                                rtol=2e-2)
 
     torch.testing.assert_close(triton_output,
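A hedged usage note (the test-file path below is an assumption inferred from the imports, not stated in the diff): the expanded parameter grid can be exercised directly with pytest.

```python
import pytest

# Path is assumed; point this at the actual test module for the diff above.
raise SystemExit(pytest.main(["tests/kernels/moe/test_batched_moe.py", "-v"]))
```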