@@ -1,30 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0

 from dataclasses import dataclass
+from typing import Optional

 import pytest
 import torch
 import triton.language as tl
-from typing import Optional

 import vllm._custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    invoke_moe_batched_triton_kernel,
-    BatchedExperts,
-    BatchedPrepareAndFinalize,
-    BatchedTritonExperts)
-from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
-                                                            get_default_config)
+    BatchedPrepareAndFinalize, BatchedTritonExperts,
+    invoke_moe_batched_triton_kernel)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 from vllm.utils import round_up

-
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]

@@ -80,10 +76,12 @@ def make_tensors(config: BatchedMMConfig):
         return BatchedMMTensors(A, B, C, num_expert_tokens)


-def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
-                             As: torch.Tensor, Bs: torch.Tensor,
+def native_w8a8_block_matmul(A: torch.Tensor,
+                             B: torch.Tensor,
+                             As: torch.Tensor,
+                             Bs: torch.Tensor,
                              block_size,
-                             output_dtype=torch.bfloat16):
+                             output_dtype=torch.bfloat16):
    """This function performs matrix multiplication with block-wise
    quantization using native torch.
    It is agnostic to the input data type and can be used for both int8 and
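The docstring this hunk ends on describes a plain-torch reference for block-quantized matmul: dequantize each block of A and B with its scale, multiply in full precision, then cast. A minimal sketch of that idea (not the vLLM implementation; the helper name and shape conventions below are illustrative assumptions):

import torch

def blockwise_matmul_sketch(A, B, As, Bs, block_n, block_k):
    # Assumed shapes: A [M, K] quantized activations with per-token-group
    # scales As [M, ceil(K / block_k)]; B [N, K] quantized weights with
    # per-block scales Bs [ceil(N / block_n), ceil(K / block_k)].
    A_dq = A.to(torch.float32) * As.repeat_interleave(block_k, dim=1)[:, :A.shape[1]]
    B_dq = B.to(torch.float32) * (Bs.repeat_interleave(block_n, dim=0)[:B.shape[0]]
                                  .repeat_interleave(block_k, dim=1)[:, :B.shape[1]])
    # Multiply in float32, then cast to the output dtype used by the tests.
    return (A_dq @ B_dq.t()).to(torch.bfloat16)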
@@ -160,16 +158,11 @@ def ref_impl(
         if A.dtype == torch.torch.float8_e4m3fn:
             if False:
                 tmp = native_w8a8_block_matmul(A[e, :, :],
-                                               B[e].transpose(0, 1),
-                                               A_scale,
-                                               B_scale,
-                                               block_shape)
+                                               B[e].transpose(0, 1), A_scale,
+                                               B_scale, block_shape)
             else:
-                tmp = ops.cutlass_scaled_mm(A[e, :, :],
-                                            B[e].transpose(0, 1),
-                                            A_scale,
-                                            B_scale,
-                                            torch.bfloat16)
+                tmp = ops.cutlass_scaled_mm(A[e, :, :], B[e].transpose(0, 1),
+                                            A_scale, B_scale, torch.bfloat16)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         else:
             C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
@@ -195,7 +188,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     in_dtype = dtype
     out_dtype = dtype

-    config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
+    config = BatchedMMConfig(in_dtype, out_dtype, num_experts,
+                             max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)

     test_output = tensors.C
@@ -209,7 +203,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     }[test_output.dtype]

     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
-    block_shape = [16, 16, 32] # 16 for k if not fp8
+    block_shape = [16, 16, 32]  # 16 for k if not fp8

     #print(f"tensors.A {tensors.A.shape}")
     #print(f"tensors.B {tensors.B.shape}")
@@ -250,19 +244,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,

     ref_output = ref_output.to(dtype=out_dtype)
     ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
-                          tensors.B.to(dtype=out_dtype),
-                          ref_output,
-                          tensors.num_expert_tokens,
-                          A_scale,
-                          B_scale,
+                          tensors.B.to(dtype=out_dtype), ref_output,
+                          tensors.num_expert_tokens, A_scale, B_scale,
                           block_shape[-2:])

-    ref_output2 = ref_impl(tensors.A,
-                           tensors.B,
-                           ref_output2,
-                           tensors.num_expert_tokens,
-                           A_scale,
-                           B_scale,
+    ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2,
+                           tensors.num_expert_tokens, A_scale, B_scale,
                            block_shape[-2:])

     rtol, atol = {
@@ -286,11 +273,17 @@ def batched_moe(
     use_fp8_w8a8: bool = False,
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64) # ?
+    max_num_tokens = round_up(a.shape[0], 64)  # ?
     fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0, use_fp8_w8a8=use_fp8_w8a8,
+        BatchedPrepareAndFinalize(max_num_tokens,
+                                  world_size=1,
+                                  dp_size=1,
+                                  rank=0,
+                                  use_fp8_w8a8=use_fp8_w8a8,
                                   block_shape=block_shape),
-        BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1,
+        BatchedTritonExperts(max_num_tokens=max_num_tokens,
+                             dp_size=1,
+                             world_size=1,
                              use_fp8_w8a8=use_fp8_w8a8,
                              block_shape=block_shape))

@@ -322,11 +315,13 @@ def torch_moe2(

     if use_fp8_w8a8:
         a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
-        #print(f"a_scale {a_scale.shape}")
     else:
         a_scale = None

-    out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device)
+    out = torch.zeros(M * topk,
+                      w2.shape[1],
+                      dtype=torch.bfloat16,
+                      device=a.device)
     num_experts = w1.shape[0]
     for i in range(num_experts):
         mask = (topk_ids == i).view(-1)
@@ -341,11 +336,8 @@ def torch_moe2(
             # a_scale[mask],
             # w1_scale[i],
             # torch.bfloat16)
-            tmp1 = native_w8a8_block_matmul(a[mask],
-                                            w1[i],
-                                            a_scale[mask],
-                                            w1_scale[i],
-                                            block_shape,
+            tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
+                                            w1_scale[i], block_shape,
                                             torch.bfloat16)
             tmp2 = SiluAndMul()(tmp1)
             tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
@@ -355,11 +347,8 @@ def torch_moe2(
             # b_scale,
             # w2_scale[i],
             # torch.bfloat16)
-            out[mask] = native_w8a8_block_matmul(tmp2,
-                                                 w2[i],
-                                                 b_scale,
-                                                 w2_scale[i],
-                                                 block_shape,
+            out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
+                                                 w2_scale[i], block_shape,
                                                  torch.bfloat16)

     return (out.view(M, -1, w2.shape[1]) *
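The return expression this hunk ends on is the standard MoE combine: reshape the flat per-(token, expert) outputs to [M, topk, hidden], weight them by the router probabilities, and sum over the top-k axis. A toy, self-contained illustration (shapes and values made up for the example):

import torch

M, topk, hidden = 4, 2, 8
expert_out = torch.randn(M, topk, hidden)   # outputs of the selected experts per token
topk_weight = torch.rand(M, topk)           # router weights for those experts
combined = (expert_out * topk_weight.view(M, topk, 1)).sum(dim=1)  # [M, hidden]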
@@ -406,23 +395,21 @@ def test_fused_moe_batched_experts(

         factor_for_scale = 1e-2
         w1_s = torch.rand(
-            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale
+            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
+            device="cuda") * factor_for_scale
         w2_s = torch.rand(
-            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale
+            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
+            device="cuda") * factor_for_scale
     else:
         w1_s = None
         w2_s = None

     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
-        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
-        # batched_output = batched_moe(a,
-        #                              w1.to(torch.bfloat16),
-        #                              w2.to(torch.bfloat16),
-        #                              topk_weight, topk_ids,
-        #                              w1_s, w2_s, False,
-        #                              block_shape)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, use_fp8_w8a8, block_shape)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, use_fp8_w8a8, block_shape)

         torch.testing.assert_close(baseline_output,
                                    batched_output,