    reason="Requires PPLX kernels",
)

-PPLX_PREPARE_COMBOS = [
+PPLX_COMBOS = [
    # TODO: figure out why this fails
    #(1, 128, 128),
+
    (2, 128, 512),
    (3, 1024, 2048),
    (4, 128, 128),
    (32, 1024, 512),
    (45, 512, 2048),
    (64, 1024, 512),
    (222, 2048, 1024),
-]
+    (256, 1408, 2048),

-PPLX_MOE_COMBOS = [
-    # (1, 128, 128),
-    (2, 128, 512),
-    (3, 1024, 2048),
-    (32, 128, 1024),
-    (45, 512, 2048),
-    (64, 1024, 1024),
-    (222, 1024, 2048),
+    #(6, 1408, 2048),
+    #(16, 1408, 2048),
+    #(199, 1408, 2048),
+    #(200, 1408, 2048),
+    #(256, 1408, 2048),
]

NUM_EXPERTS = [8, 64]
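Note: the merged `PPLX_COMBOS` list feeds the `@pytest.mark.parametrize("mnk", ...)` decorators below. A minimal, self-contained sketch of how such a shape list expands into a test matrix (the combos and test name here are illustrative, not from this diff):

```python
# Sketch: a list of (m, n, k) shape tuples crossed with an expert count
# expands into one test case per combination.
import pytest

SHAPE_COMBOS = [(2, 128, 512), (3, 1024, 2048)]  # illustrative (m, n, k)

@pytest.mark.parametrize("mnk", SHAPE_COMBOS)
@pytest.mark.parametrize("e", [8, 64])  # num experts
def test_shapes(mnk: tuple[int, int, int], e: int):
    m, n, k = mnk
    # pytest runs this body once per (mnk, e) pair: 2 * 2 = 4 cases here.
    assert m > 0 and n > 0 and k > 0 and e > 0
```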
@@ -280,10 +278,17 @@ def pplx_prepare_finalize(
        device=device,
    )

+    if quant_dtype is not None and not per_act_token_quant and block_shape is None:
+        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
+
    b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
        a_chunk,
-        None,
-        None,
+        a1_scale,
+        a2_scale,
        chunk_topk_weight,
        chunk_topk_ids,
        num_experts,
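The added branch supplies dummy scalar scales only on the per-tensor quantization path (a `quant_dtype` is set, but neither per-token scales nor a block shape is in play); every other mode passes `None`. A standalone sketch of that defaulting rule, using a hypothetical helper name:

```python
# Sketch of the scale-defaulting rule added above. Per-tensor quantization
# gets a scalar float32 scale of 1.0; all other modes return None so the
# kernel computes scales itself. Helper name is illustrative.
from typing import Optional
import torch

def default_activation_scales(
    quant_dtype: Optional[torch.dtype],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    if quant_dtype is not None and not per_act_token_quant and block_shape is None:
        scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        return scale, scale.clone()
    return None, None
```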
@@ -364,7 +369,7 @@ def _pplx_prepare_finalize(

# TODO (bnell): this test point does not work for M==1 due to how the test
# is written, not due to limitations of the pplx kernels.
-@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
+@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@@ -423,7 +428,9 @@ def pplx_moe(
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
-    qtype: Optional[torch.dtype] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
    use_compile: bool = False,
@@ -442,7 +449,7 @@ def pplx_moe(
        max_num_tokens,
        hidden_dim,
        a.dtype,
-        qtype,
+        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
@@ -478,7 +485,7 @@ def pplx_moe(
    experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                   world_size=world_size,
                                   dp_size=dp_size,
-                                  use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+                                  use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
                                   block_shape=block_shape)

    fused_experts = FusedMoEModularKernel(
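The hunk ends mid-construction of `FusedMoEModularKernel`, which pairs the all-to-all prepare/finalize stage with the batched experts compute. A stand-in sketch of that two-stage composition (illustrative classes only, not the vLLM interfaces):

```python
# Illustrative two-stage composition behind a modular MoE kernel: a
# prepare/finalize stage that scatters tokens to expert ranks and gathers
# results back, wrapped around an experts stage that does the math.
import torch

class PrepareFinalize:
    def prepare(self, a: torch.Tensor) -> torch.Tensor:
        return a  # scatter tokens to expert ranks (no-op in this sketch)

    def finalize(self, out: torch.Tensor) -> torch.Tensor:
        return out  # gather results back (no-op in this sketch)

class Experts:
    def apply(self, a: torch.Tensor) -> torch.Tensor:
        return a * 2.0  # stand-in for the fused expert MLPs

class ModularKernel:
    def __init__(self, prepare_finalize: PrepareFinalize, experts: Experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, a: torch.Tensor) -> torch.Tensor:
        return self.prepare_finalize.finalize(
            self.experts.apply(self.prepare_finalize.prepare(a)))
```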
@@ -496,12 +503,8 @@ def pplx_moe(
    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

    if w1_scale is not None:
-        if not per_act_token_quant:
-            w1_scale_chunk = w1_scale
-            w2_scale_chunk = w2_scale
-        else:
-            w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
-            w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
+        w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
+        w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
    else:
        w1_scale_chunk = None
        w2_scale_chunk = None
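The change now chunks weight scales per rank unconditionally. `chunk_by_rank` comes from the test utilities; a plausible sketch of its behavior (an assumption about the helper, not its actual implementation):

```python
# Assumed behavior of the chunk_by_rank test helper: slice the leading
# dimension into world_size contiguous pieces and return the piece that
# belongs to this rank.
import torch

def chunk_by_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    chunk = (t.shape[0] + world_size - 1) // world_size  # ceil division
    return t[rank * chunk:(rank + 1) * chunk]
```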
@@ -526,6 +529,8 @@ def pplx_moe(
        chunk_topk_ids,
        w1_scale=w1_scale_chunk,
        w2_scale=w2_scale_chunk,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
        global_num_experts=num_experts)

    if use_cudagraphs:
@@ -540,6 +545,8 @@ def pplx_moe(
            chunk_topk_ids,
            w1_scale=w1_scale_chunk,
            w2_scale=w2_scale_chunk,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
            global_num_experts=num_experts)

    torch.cuda.synchronize()
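For the `use_cudagraphs` path, the usual capture-then-replay pattern with PyTorch's public CUDA graph API applies. A generic sketch (the static tensors and `run_once` callable are placeholders for the MoE invocation):

```python
# Generic torch.cuda.CUDAGraph capture/replay pattern: warm up, capture
# once into static buffers, then replay with updated inputs.
import torch

def run_once(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # stand-in for the fused-experts call

static_in = torch.ones(8, device="cuda")
run_once(static_in)  # warmup outside the graph
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = run_once(static_in)

static_in.copy_(torch.full((8,), 3.0, device="cuda"))
graph.replay()  # recomputes static_out in place from the new input
torch.cuda.synchronize()
```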
@@ -562,7 +569,7 @@ def _pplx_moe(
    topk: int,
    w1_s: Optional[torch.Tensor] = None,
    w2_s: Optional[torch.Tensor] = None,
-    qtype: Optional[torch.dtype] = None,
+    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_internode: bool = False,
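`per_act_token_quant` toggles whether activation scales are computed per token rather than once per tensor. A minimal fp8-scale sketch of that distinction (illustrative only, not code from this diff):

```python
# Illustrative contrast between per-tensor and per-token activation
# scales for fp8 quantization.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def activation_scales(a: torch.Tensor, per_act_token_quant: bool) -> torch.Tensor:
    if per_act_token_quant:
        # One scale per token (row): shape [num_tokens, 1].
        return a.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    # One scale for the whole tensor: shape [1].
    return a.abs().amax().reshape(1).float() / FP8_MAX
```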
@@ -584,48 +591,112 @@ def _pplx_moe(
    moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

    device = torch.device("cuda", pgi.rank)
+    rank = pgi.rank
+    world_size = pgi.world_size
    a = a.to(device)
    w1 = w1.to(device)
    w2 = w2.to(device)
    w1_s = w1_s.to(device) if w1_s is not None else None
    w2_s = w2_s.to(device) if w2_s is not None else None

+    if quant_dtype is not None and not per_act_token_quant and block_shape is None:
+        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
+
    with set_current_vllm_config(vllm_config), override_config(moe_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_experts(a,
-                                     w1,
-                                     w2,
-                                     topk_weight,
-                                     topk_ids,
-                                     w1_scale=w1_s,
-                                     w2_scale=w2_s,
-                                     quant_dtype=qtype,
-                                     per_act_token_quant=per_act_token_quant,
-                                     block_shape=block_shape)
-
-        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
-                               a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
-                               qtype, per_act_token_quant, block_shape)
+
+        if False:  # disabled debug path: run references on per-rank chunks
+            a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+            topk_weight_chunk = chunk_by_rank(topk_weight, rank, world_size).to(device)
+            topk_ids_chunk = chunk_by_rank(topk_ids, rank, world_size).to(device)
+            w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
+            w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
+
+            if w1_s is not None:
+                w1_s_chunk = chunk_by_rank(w1_s, rank, world_size).to(device)
+                w2_s_chunk = chunk_by_rank(w2_s, rank, world_size).to(device)
+            else:
+                w1_s_chunk = None
+                w2_s_chunk = None
+        else:
+            a_chunk = a
+            topk_weight_chunk = topk_weight
+            topk_ids_chunk = topk_ids
+            w1_chunk = w1
+            w2_chunk = w2
+            w1_s_chunk = w1_s
+            w2_s_chunk = w2_s
+
+        torch_output = torch_experts(
+            a_chunk,
+            w1_chunk,
+            w2_chunk,
+            topk_weight_chunk,
+            topk_ids_chunk,
+            w1_scale=w1_s_chunk,
+            w2_scale=w2_s_chunk,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
+
+        batched_output = naive_batched_moe(
+            a_chunk,
+            w1_chunk,
+            w2_chunk,
+            topk_weight_chunk,
+            topk_ids_chunk,
+            w1_scale=w1_s_chunk,
+            w2_scale=w2_s_chunk,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
+
+        pplx_output = pplx_moe(
+            group_name,
+            rank,
+            world_size,
+            dp_size,
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape)

        # all reduce on pplx?
        #torch.distributed.all_reduce(pplx_output)

-        batched_output = naive_batched_moe(a, w1, w2, topk_weight,
-            topk_ids, w1_s, w2_s, qtype, per_act_token_quant, block_shape)
-
        chunked_torch_output = chunk_by_rank(torch_output, pgi.rank,
                                             pgi.world_size).to(pplx_output.device)

        torch.testing.assert_close(pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2)
-        # torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
+        torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)

    if use_internode:
        nvshmem_finalize()


-@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
+@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
+#@pytest.mark.parametrize("e", [32])
+#@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
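The assertions above compare each rank's distributed output against the matching slice of the single-device reference. Sketched below under the same `chunk_by_rank` assumption as earlier (helper name and signature are illustrative):

```python
# Sketch of the comparison pattern: the reference is computed on the full
# batch, then sliced to this rank's rows before being checked against the
# distributed (pplx) output.
import torch

def check_rank_output(pplx_out: torch.Tensor, torch_ref: torch.Tensor,
                      rank: int, world_size: int) -> None:
    chunk = (torch_ref.shape[0] + world_size - 1) // world_size
    ref_chunk = torch_ref[rank * chunk:(rank + 1) * chunk].to(pplx_out.device)
    torch.testing.assert_close(pplx_out, ref_chunk, atol=3e-2, rtol=3e-2)
```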