from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.utils import torch_experts
+ from tests.kernels.quant_utils import dequant
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

    reason="Requires PPLX kernels",
)

- PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
-                        (222, 2048, 1024)]
+ PPLX_PREPARE_COMBOS = [
+     # (1, 128, 128),
+     (4, 128, 128),
+     (32, 1024, 512),
+     # (45, 512, 2048),
+     (64, 1024, 512),
+     (222, 2048, 1024),
+ ]

PPLX_MOE_COMBOS = [
    (1, 128, 128),
@@ -194,18 +201,24 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    return t[(r * chunk):(r + 1) * chunk]


+ def dummy_work(a: torch.Tensor) -> torch.Tensor:
+     return a  # * 1.5
+
+
def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
-     a_scale: Optional[torch.Tensor],
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
+     quant_dtype: Optional[torch.dtype],
+     block_shape: Optional[list[int]],
+     per_act_token_quant: bool,
    group_name: Optional[str],
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
+         PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)

    assert torch.cuda.current_device() == pgi.local_rank

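As an aside on the test helpers: chunk_by_rank (the function this hunk is anchored to) hands each rank a contiguous slice of the token dimension, so every rank compares only its own rows of the reference output against the pplx result. A minimal self-contained sketch of that splitting, assuming a simple ceil-divide chunk size (the helper name below is illustrative, not the test's own):

import torch


def chunk_by_rank_sketch(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Ceil-divide so earlier ranks get the larger chunks and the last rank
    # simply takes whatever remains.
    chunk = (t.shape[0] + world_size - 1) // world_size
    return t[rank * chunk:(rank + 1) * chunk]


# Example: 5 tokens over 2 ranks -> rank 0 sees rows 0-2, rank 1 sees rows 3-4.
tokens = torch.arange(5).unsqueeze(1)
assert chunk_by_rank_sketch(tokens, 0, 2).shape[0] == 3
assert chunk_by_rank_sketch(tokens, 1, 2).shape[0] == 2
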
@@ -214,7 +227,16 @@ def pplx_prepare_finalize(
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
-     max_num_tokens = rank_chunk(num_tokens, 0, world_size)
+     max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
+
+     hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
+         max_num_tokens,
+         hidden_dim,
+         a.dtype,
+         quant_dtype,
+         per_act_token_quant=per_act_token_quant,
+         block_shape=block_shape,
+     )

    args = dict(
        max_num_tokens=max_num_tokens,
@@ -224,8 +246,8 @@ def pplx_prepare_finalize(
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
-         hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
-         hidden_dim_scale_bytes=0,
+         hidden_dim_bytes=hidden_dim_bytes,
+         hidden_dim_scale_bytes=scale_bytes,
    )

    if group_name is None:
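The two byte counts fed into args above size the pplx all-to-all buffers: space for each token's hidden state in its transfer dtype, plus space for quantization scales whenever a quant_dtype is in play. The sketch below is only a rough guess at that kind of accounting, for illustration; the real pplx_hidden_dim_scale_bytes in vllm also covers details (such as alignment) that are omitted here.

import math
from typing import Optional

import torch


def hidden_dim_scale_bytes_sketch(
    hidden_dim: int,
    act_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[int, int]:
    if quant_dtype is None:
        # Unquantized path: ship raw activations, no scales needed.
        return hidden_dim * act_dtype.itemsize, 0
    hidden_bytes = hidden_dim * quant_dtype.itemsize
    if per_act_token_quant or block_shape is None:
        num_scales = 1  # one scale per token (or a single per-tensor scale)
    else:
        num_scales = math.ceil(hidden_dim / block_shape[1])  # one per K block
    return hidden_bytes, num_scales * torch.float32.itemsize
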
@@ -257,10 +279,17 @@ def pplx_prepare_finalize(
        num_experts,
        None,
        False,
-         FusedMoEQuantConfig(),
+         FusedMoEQuantConfig(
+             quant_dtype,
+             per_act_token_quant,
+             False,
+             block_shape,
+         ),
    )

-     b_a = b_a * 1.5
+     # Do some fake work
+     # print(f"INTER {b_a.shape} {b_a_scale.shape if b_a_scale is not None else None}")
+     b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))

    out = torch.full(
        (max_num_tokens, hidden_dim),
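dequant undoes the activation quantization performed inside prepare() so the "fake work" and the later combine run on the bfloat16 activations again. The helper below is a plausible sketch of such a function, written as an assumption for illustration rather than the actual tests.kernels.quant_utils.dequant: cast back to the working dtype and broadcast the scales per token or per block of columns.

from typing import Optional

import torch


def dequant_sketch(x_q: torch.Tensor,
                   x_scale: Optional[torch.Tensor],
                   block_shape: Optional[list[int]],
                   per_act_token_quant: bool,
                   out_dtype: torch.dtype) -> torch.Tensor:
    if x_scale is None:
        # Nothing was quantized; only the dtype needs fixing up.
        return x_q.to(out_dtype)
    x = x_q.to(torch.float32)
    if per_act_token_quant or block_shape is None:
        # One scale per token (or per tensor): a plain broadcast multiply.
        return (x * x_scale.to(torch.float32)).to(out_dtype)
    # Block-wise scales: stretch each scale over its block_k span of columns.
    block_k = block_shape[1]
    x_scale = x_scale.to(torch.float32).repeat_interleave(block_k, dim=-1)
    return (x * x_scale[..., :x.shape[-1]]).to(out_dtype)
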
@@ -290,10 +319,12 @@ def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
-     a_scale: Optional[torch.Tensor],
    score: torch.Tensor,
    topk: torch.Tensor,
    num_experts: int,
+     quant_dtype: Optional[torch.dtype],
+     block_shape: Optional[list[int]],
+     per_act_token_quant: bool,
    use_internode: bool,
):
    if use_internode:
@@ -307,24 +338,35 @@ def _pplx_prepare_finalize(
        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
        group_name = cpu_group.group_name

-     device = pgi.device
+     # device = pgi.device

    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-     k = a.shape[1]
+     m, k = a.shape

-     a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
+     a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)  # .to(device)

-     torch_output = (a_rep.view(-1, topk, k) * 1.5 *
-                     topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
-                         a.dtype)
+     if True:
+         torch_output = (a_rep.view(m, topk, k) *
+                         topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1)
+     else:
+         import vllm._custom_ops as ops
+         a_rep = a_rep.view(m, topk, k)
+         a_rep.mul_(topk_weight.view(m, topk, 1).to(a_rep.dtype))
+         torch_output = torch.empty_like(a)
+         ops.moe_sum(a_rep, torch_output)

-     pplx_output = pplx_prepare_finalize(pgi, dp_size, a, a_scale, topk_weight, topk_ids,
-                                         num_experts, group_name)
+     pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
+                                         num_experts, quant_dtype, block_shape,
+                                         per_act_token_quant, group_name)

    torch_output = chunk_by_rank(torch_output, pgi.rank,
                                 pgi.world_size).to(pplx_output.device)

-     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
+     # torch.set_printoptions(profile="full")
+     # print(f"PPLX {pplx_output.shape}\n{pplx_output.shape}")
+     # print(f"TORCH {torch_output.shape}\n{torch_output.shape}")
+
+     torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)

    if use_internode:
        nvshmem_finalize()
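The torch reference above replicates each token topk times (dummy_work is currently the identity) and reduces the copies with the router weights. A small self-contained check of that reduction, with made-up shapes, purely as an illustration:

import torch

m, topk, k = 4, 2, 8
a = torch.randn(m, k)
topk_weight = torch.rand(m, topk)

a_rep = torch.repeat_interleave(a, topk, dim=0)    # (m * topk, k)
out = (a_rep.view(m, topk, k) *
       topk_weight.view(m, topk, 1)).sum(dim=1)    # (m, k)

# Each token's copies are identical, so this is a per-token weighted sum.
expected = a * topk_weight.sum(dim=1, keepdim=True)
torch.testing.assert_close(out, expected)
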
@@ -333,11 +375,10 @@ def _pplx_prepare_finalize(
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
- # TODO (bnell) add fp8 tests
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
- @pytest.mark.parametrize("dtype", [torch.bfloat16])  # torch.float8_e4m3fn,
+ @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@@ -356,28 +397,31 @@ def test_pplx_prepare_finalize(
    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        act_dtype = torch.bfloat16
+         quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        act_dtype = dtype
+         quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
-         pytest.skip("Skip illgal quantization combination")
+         pytest.skip("Skip illegal quantization combination")

    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"

+     # print(f"MNK = {mnk}")
+
    a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
    score = torch.randn((m, e), device=device, dtype=act_dtype)

-     a, a_scale = moe_kernel_quantize_input(a, None, dtype, False, block_shape)
-
-     parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, a_scale, score,
-                     topk, e, use_internode)
+     parallel_launch(world_size, _pplx_prepare_finalize, dp_size,
+                     a, score, topk, e, quant_dtype, block_shape,
+                     per_act_token_quant, use_internode)


def pplx_moe(
@@ -661,7 +705,7 @@ def test_pplx_moe(
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
-         pytest.skip("Skip illgal quantization combination")
+         pytest.skip("Skip illegal quantization combination")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)