20
20
21
21
from tests .kernels .moe .utils import make_test_weights , naive_batched_moe
22
22
from tests .kernels .utils import torch_experts
23
- from tests .kernels .quant_utils import dequant
23
+ from tests .kernels .quant_utils import batched_dequant
24
24
from vllm .config import VllmConfig , set_current_vllm_config
25
25
from vllm .model_executor .layers .fused_moe import fused_topk , override_config
26
26
from vllm .model_executor .layers .fused_moe .config import FusedMoEQuantConfig
41
41
)
42
42
43
43
PPLX_PREPARE_COMBOS = [
44
- # (1, 128, 128),
44
+ # TODO: figure out why this fails
45
+ #(1, 128, 128),
46
+ (2 , 128 , 512 ),
47
+ (3 , 1024 , 2048 ),
45
48
(4 , 128 , 128 ),
46
49
(32 , 1024 , 512 ),
47
- # (45, 512, 2048),
50
+ (45 , 512 , 2048 ),
48
51
(64 , 1024 , 512 ),
49
52
(222 , 2048 , 1024 ),
50
53
]
51
54
52
55
PPLX_MOE_COMBOS = [
53
- (1 , 128 , 128 ),
56
+ # (1, 128, 128),
54
57
(2 , 128 , 512 ),
55
58
(3 , 1024 , 2048 ),
56
59
(32 , 128 , 1024 ),
@@ -202,7 +205,7 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
202
205
203
206
204
207
def dummy_work (a : torch .Tensor ) -> torch .Tensor :
205
- return a # * 1.5
208
+ return a * 1.1
206
209
207
210
208
211
def pplx_prepare_finalize (
@@ -270,6 +273,13 @@ def pplx_prepare_finalize(
270
273
chunk_topk_weight = chunk_by_rank (topk_weight , rank , world_size ).to (device )
271
274
chunk_topk_ids = chunk_by_rank (topk_ids , rank , world_size ).to (device )
272
275
276
+ out = torch .full (
277
+ (max_num_tokens , hidden_dim ),
278
+ torch .nan ,
279
+ dtype = a .dtype ,
280
+ device = device ,
281
+ )
282
+
273
283
b_a , b_a_scale , expert_num_tokens , _ , _ = prepare_finalize .prepare (
274
284
a_chunk ,
275
285
None ,
@@ -287,16 +297,10 @@ def pplx_prepare_finalize(
287
297
),
288
298
)
289
299
290
- # Do some fake work
291
- #print(f"INTER {b_a.shape} {b_a_scale.shape if b_a_scale is not None else None}")
292
- b_a = dummy_work (dequant (b_a , b_a_scale , block_shape , per_act_token_quant , a .dtype ))
300
+ #print(f"B_A_SCALE = {b_a.shape}, {b_a_scale.shape if b_a_scale is not None else None}, {per_act_token_quant} {block_shape}, {a_chunk.shape}")
301
+ # TOOD: shouldn't need batched_dequant
293
302
294
- out = torch .full (
295
- (max_num_tokens , hidden_dim ),
296
- torch .nan ,
297
- dtype = a .dtype ,
298
- device = device ,
299
- )
303
+ b_a = dummy_work (batched_dequant (b_a , b_a_scale , block_shape , per_act_token_quant , a .dtype ))
300
304
301
305
prepare_finalize .finalize (
302
306
out ,
@@ -338,49 +342,34 @@ def _pplx_prepare_finalize(
338
342
cpu_group = torch .distributed .new_group (group_ranks , backend = "gloo" )
339
343
group_name = cpu_group .group_name
340
344
341
- #device = pgi.device
342
-
343
345
topk_weight , topk_ids , _ = fused_topk (a , score , topk , False )
344
346
m , k = a .shape
345
347
346
- a_rep = torch .repeat_interleave (dummy_work (a ), topk , dim = 0 ) #.to(device)
348
+ a_rep = torch .repeat_interleave (dummy_work (a ), topk , dim = 0 )
347
349
348
- if True :
349
- torch_output = (a_rep .view (m , topk , k ) *
350
- topk_weight .view (m , topk , 1 ).to (a_rep .dtype )).sum (dim = 1 )
351
- else :
352
- import vllm ._custom_ops as ops
353
- a_rep = a_rep .view (m , topk , k )
354
- a_rep .mul_ (topk_weight .view (m , topk , 1 ).to (a_rep .dtype ))
355
- torch_output = torch .empty_like (a )
356
- ops .moe_sum (a_rep , torch_output )
350
+ torch_output = (a_rep .view (m , topk , k ) *
351
+ topk_weight .view (m , topk , 1 ).to (a_rep .dtype )).sum (dim = 1 )
357
352
358
353
pplx_output = pplx_prepare_finalize (pgi , dp_size , a , topk_weight , topk_ids ,
359
354
num_experts , quant_dtype , block_shape ,
360
355
per_act_token_quant , group_name )
361
356
362
- torch_output = chunk_by_rank (torch_output , pgi .rank ,
363
- pgi .world_size ).to (pplx_output .device )
364
-
365
- #torch.set_printoptions(profile="full")
366
- #print(f"PPLX {pplx_output.shape}\n{pplx_output.shape}")
367
- #print(f"TORCH {torch_output.shape}\n{torch_output.shape}")
357
+ torch_output = chunk_by_rank (torch_output , pgi .rank , pgi .world_size ).to (pgi .device )
368
358
369
359
torch .testing .assert_close (pplx_output , torch_output , atol = 3e-2 , rtol = 3e-2 )
370
360
371
361
if use_internode :
372
362
nvshmem_finalize ()
373
363
374
364
375
- # TODO (bnell): this test point does not work for odd M due to how the test is
376
- # written, not due to limitations of the pplx kernels. The pplx_moe
377
- # test below is able to deal with odd M.
365
+ # TODO (bnell): this test point does not work for M==1 due to how the test
366
+ # is written, not due to limitations of the pplx kernels.
378
367
@pytest .mark .parametrize ("mnk" , PPLX_PREPARE_COMBOS )
379
368
@pytest .mark .parametrize ("e" , NUM_EXPERTS )
380
369
@pytest .mark .parametrize ("topk" , TOP_KS )
381
370
@pytest .mark .parametrize ("dtype" , [torch .float8_e4m3fn , torch .bfloat16 ])
382
371
@pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]])
383
- @pytest .mark .parametrize ("per_act_token_quant" , [False ])
372
+ @pytest .mark .parametrize ("per_act_token_quant" , [False , True ])
384
373
@pytest .mark .parametrize ("block_shape" , [None , [128 , 128 ]])
385
374
@pytest .mark .parametrize ("use_internode" , [False ])
386
375
@requires_pplx
@@ -414,8 +403,6 @@ def test_pplx_prepare_finalize(
414
403
world_size , dp_size = world_dp_size
415
404
device = "cuda"
416
405
417
- #print(f"MNK = {mnk}")
418
-
419
406
a = torch .randn ((m , k ), device = device , dtype = act_dtype ) / 10
420
407
score = torch .randn ((m , e ), device = device , dtype = act_dtype )
421
408
@@ -508,10 +495,13 @@ def pplx_moe(
508
495
w1_chunk = chunk_by_rank (w1 , rank , world_size ).to (device )
509
496
w2_chunk = chunk_by_rank (w2 , rank , world_size ).to (device )
510
497
511
- # TODO scale chunk function
512
498
if w1_scale is not None :
513
- w1_scale_chunk = chunk_by_rank (w1_scale , rank , world_size ).to (device )
514
- w2_scale_chunk = chunk_by_rank (w2_scale , rank , world_size ).to (device )
499
+ if not per_act_token_quant :
500
+ w1_scale_chunk = w1_scale
501
+ w2_scale_chunk = w2_scale
502
+ else :
503
+ w1_scale_chunk = chunk_by_rank (w1_scale , rank , world_size ).to (device )
504
+ w2_scale_chunk = chunk_by_rank (w2_scale , rank , world_size ).to (device )
515
505
else :
516
506
w1_scale_chunk = None
517
507
w2_scale_chunk = None
@@ -562,48 +552,6 @@ def pplx_moe(
562
552
return out
563
553
564
554
565
- def _batched_moe (pgi , dp_size , a , w1 , w2 , topk_weight , topk_ids ):
566
- assert torch .cuda .current_device () == pgi .local_rank
567
-
568
- num_experts = w1 .shape [0 ]
569
- device = pgi .device
570
- rank = pgi .rank
571
- world_size = pgi .world_size
572
- max_num_tokens = rank_chunk (a .shape [0 ], 0 , world_size )
573
-
574
- prepare_finalize = BatchedPrepareAndFinalize (
575
- max_num_tokens = max_num_tokens ,
576
- world_size = world_size ,
577
- dp_size = dp_size ,
578
- rank = rank ,
579
- )
580
-
581
- experts = NaiveBatchedExperts (max_num_tokens = a .shape [0 ],
582
- world_size = 1 ,
583
- dp_size = 1 )
584
-
585
- fused_experts = FusedMoEModularKernel (
586
- prepare_finalize ,
587
- experts ,
588
- )
589
-
590
- # Note: workers with the same dp_rank must use the exact same inputs.
591
- a_chunk = chunk_by_rank (a , rank , world_size ).to (device )
592
- chunk_topk_weight = chunk_by_rank (topk_weight , rank , world_size ).to (device )
593
- chunk_topk_ids = chunk_by_rank (topk_ids , rank , world_size ).to (device )
594
-
595
- out = fused_experts (
596
- a_chunk ,
597
- # Chunking weights like this only works for batched format
598
- chunk_by_rank (w1 , rank , world_size ).to (device ),
599
- chunk_by_rank (w2 , rank , world_size ).to (device ),
600
- chunk_topk_weight ,
601
- chunk_topk_ids ,
602
- global_num_experts = num_experts )
603
-
604
- return out
605
-
606
-
607
555
def _pplx_moe (
608
556
pgi : ProcessGroupInfo ,
609
557
dp_size : int ,
@@ -654,18 +602,22 @@ def _pplx_moe(
654
602
quant_dtype = qtype ,
655
603
per_act_token_quant = per_act_token_quant ,
656
604
block_shape = block_shape )
605
+
657
606
pplx_output = pplx_moe (group_name , pgi .rank , pgi .world_size , dp_size ,
658
607
a , w1 , w2 , topk_weight , topk_ids , w1_s , w2_s ,
659
608
qtype , per_act_token_quant , block_shape )
660
- # TODO (bnell): fix + re-enable
661
- #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
662
- # topk_ids)
663
609
664
- torch_output = chunk_by_rank (torch_output , pgi .rank ,
665
- pgi .world_size ).to (pplx_output .device )
610
+ # all reduce on pplx?
611
+ #torch.distributed.all_reduce(pplx_output)
612
+
613
+ batched_output = naive_batched_moe (a , w1 , w2 , topk_weight ,
614
+ topk_ids , w1_s , w2_s , qtype , per_act_token_quant , block_shape )
615
+
616
+ chunked_torch_output = chunk_by_rank (torch_output , pgi .rank ,
617
+ pgi .world_size ).to (pplx_output .device )
666
618
667
- torch .testing .assert_close (pplx_output , torch_output , atol = 2e -2 , rtol = 0 )
668
- #torch.testing.assert_close(batched_output, torch_output, atol=2e -2, rtol=0 )
619
+ torch .testing .assert_close (pplx_output , chunked_torch_output , atol = 3e -2 , rtol = 3e-2 )
620
+ #torch.testing.assert_close(batched_output, torch_output, atol=3e -2, rtol=3e-2 )
669
621
670
622
if use_internode :
671
623
nvshmem_finalize ()
0 commit comments