@@ -9,9 +9,9 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache,
+    moe_kernel_quantize_input)


 @triton.jit
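
Note: for readers unfamiliar with the quantization helper now imported here, the following standalone sketch (not vLLM's implementation; all names and sizes are illustrative) shows what grouped FP8 activation quantization computes: every block_k-wide tile of the last dimension shares one float32 scale, which is the bookkeeping the call sites later in this diff rely on.

    # Standalone illustration of grouped FP8 quantization (one scale per
    # block_k-wide tile of the last dimension). Not vLLM's implementation.
    import torch

    def group_quant_fp8_sketch(x: torch.Tensor, block_k: int):
        M, K = x.shape
        assert K % block_k == 0
        groups = x.float().view(M, K // block_k, block_k)
        fp8_max = torch.finfo(torch.float8_e4m3fn).max           # 448 for e4m3
        scale = groups.abs().amax(dim=-1, keepdim=True) / fp8_max
        scale = scale.clamp(min=1e-12)                           # avoid divide-by-zero
        q = (groups / scale).to(torch.float8_e4m3fn).view(M, K)
        return q, scale.squeeze(-1)                              # [M, K], [M, K // block_k]

    q, s = group_quant_fp8_sketch(torch.randn(4, 256), block_k=128)
    print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
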
@@ -47,6 +47,7 @@ def moe_mmk(
         compute_type: tl.constexpr,
         use_w8a8: tl.constexpr,
         use_w8a16: tl.constexpr):
+
     offs_k = tl.arange(0, BLOCK_K)

     if use_w8a16:
@@ -325,6 +326,7 @@ def invoke_moe_batched_triton_kernel(
         use_int4_w4a16: bool,
         config: dict[str, int],
         block_shape: Optional[list[int]] = None):
+
     assert not use_int4_w4a16
     max_num_tokens = A.size(1)
     K = A.size(2)
@@ -393,15 +395,17 @@ def __init__(self,
                  world_size: int,
                  dp_size: int,
                  rank: int,
-                 use_fp8_w8a8: bool = False,
+                 qtype: Optional[torch.dtype] = None,
+                 per_act_token: bool = False,
                  block_shape: Optional[list[int]] = None):
         super().__init__()
         self.world_size = world_size
         self.dp_size = dp_size
         self.rank = rank
         self.max_num_tokens = max_num_tokens
-        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.per_act_token = per_act_token
         self.block_shape = block_shape
+        self.qtype = qtype

     def prepare(
         self,
@@ -445,10 +449,10 @@ def prepare(

         b_a1 = torch.zeros(
             (num_local_experts, self.max_num_tokens, hidden_dim),
-            dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else a1.dtype,
+            dtype=self.qtype if self.qtype is not None else a1.dtype,
             device=a1.device)

-        if self.use_fp8_w8a8:
+        if self.qtype is not None:
             k_tiles = (hidden_dim + block_k - 1) // block_k
             b_a1_scale = torch.zeros(
                 (num_local_experts, self.max_num_tokens, k_tiles),
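
Note: a small worked example of the scale bookkeeping above, with made-up sizes: one float32 scale is stored per block_k-wide tile of the hidden dimension, so b_a1_scale holds k_tiles entries per token slot.

    # Illustrative numbers only; hidden_dim, block_k and expert counts are made up.
    hidden_dim, block_k = 7168, 128
    k_tiles = (hidden_dim + block_k - 1) // block_k      # ceil(7168 / 128) = 56
    num_local_experts, max_num_tokens = 8, 256
    print((num_local_experts, max_num_tokens, k_tiles))  # (8, 256, 56) -> b_a1_scale shape
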
@@ -465,10 +469,20 @@ def prepare(
             rows = torch.count_nonzero(topks.flatten())
             rhs = a1[:topks.numel()][topks]
             idx = expert_id - first_expert
-            if self.use_fp8_w8a8:
-                # TODO: use _fp8_quantize
-                b_a1[idx, :rows, :], b_a1_scale[
-                    idx, :rows] = per_token_group_quant_fp8(rhs, block_k)
+            if self.qtype is not None:
+                if a1_scale is not None:
+                    rhs_a1_scale = a1_scale[:topks.numel()][topks]
+                else:
+                    rhs_a1_scale = None
+                b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = (
+                    moe_kernel_quantize_input(
+                        rhs,
+                        rhs_a1_scale,
+                        self.qtype,
+                        self.per_act_token,
+                        self.block_shape,
+                    )
+                )
             else:
                 b_a1[idx, :rows, :] = rhs

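
Note: a synthetic illustration of the row selection used above. topks is a boolean mask over tokens for one expert, and the selected rows are packed into the first `rows` slots of that expert's batch; the data here is made up.

    import torch

    a1 = torch.arange(12.0).view(6, 2)                   # 6 tokens, hidden_dim = 2
    topks = torch.tensor([True, False, True, False, False, True])
    rows = torch.count_nonzero(topks.flatten())          # tensor(3)
    rhs = a1[:topks.numel()][topks]                      # tokens 0, 2 and 5 -> shape [3, 2]
    print(rows.item(), rhs.shape)                        # 3 torch.Size([3, 2])
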
@@ -524,7 +538,6 @@ def __init__(
         block_m: Optional[int] = None,
     ):
         super().__init__()
-        #assert block_shape is None
         assert block_m is None
         assert not use_int8_w8a8, "NYI"
         assert not use_int8_w8a16, "NYI"
@@ -615,6 +628,42 @@ def apply(
         return out


+def batched_moe_kernel_quantize_input(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    num_tokens: int,
+    E: int,
+    N: int,
+    expert_num_tokens: torch.Tensor,
+    qtype: Optional[torch.dtype],
+    per_channel_quant: bool,
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if qtype is not None:
+        assert block_shape is not None
+        A_q = torch.empty_like(A, dtype=qtype)
+        block_n, block_k = block_shape
+        n_tiles = ((N // 2) + block_n - 1) // block_n
+        scale_shape = (E, num_tokens, n_tiles)
+        A_q_scale = torch.empty(scale_shape,
+                                dtype=torch.float32,
+                                device=A.device)
+        for e in range(E):
+            num_tokens = expert_num_tokens[e]
+            if num_tokens > 0:
+                A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
+                    A[e, :num_tokens],
+                    A_scale[e, :num_tokens] if A_scale is not None else None,
+                    qtype,
+                    per_channel_quant,
+                    [block_k, block_n])
+                A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
+
+        return A_q, A_q_scale
+    else:
+        return A, A_scale
+
+
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     """
     A Triton based MoE expert class that operates on expert batched format,
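
Note: a sketch of how the new helper above might be driven, assuming a vLLM build containing this change and a CUDA device for the underlying FP8 quantization kernels; shapes and per-expert token counts are illustrative. A carries up to num_tokens rows per expert, and only the first expert_num_tokens[e] rows of each expert are quantized.

    import torch

    # Sketch only: values are illustrative, not taken from a real run.
    E, num_tokens, N = 4, 32, 512
    A = torch.randn(E, num_tokens, N // 2, device="cuda", dtype=torch.bfloat16)
    expert_num_tokens = torch.tensor([32, 7, 0, 15], device="cuda")
    A_q, A_q_scale = batched_moe_kernel_quantize_input(
        A, None, num_tokens, E, N, expert_num_tokens,
        qtype=torch.float8_e4m3fn, per_channel_quant=False,
        block_shape=[128, 128])
    # A_q: [4, 32, 256] float8_e4m3fn, A_q_scale: [4, 32, 2] float32
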
@@ -630,6 +679,7 @@ def __init__(
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
         block_shape: Optional[list[int]] = None,
+        per_act_token: bool = False,
         world_size: int = 1,
         dp_size: int = 1,
     ):
@@ -644,6 +694,8 @@ def __init__(
         assert not use_int4_w4a16, "NYI"
         self.world_size = world_size
         self.dp_size = dp_size
+        self.per_act_token = per_act_token
+        self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None

     def workspace_shapes(
         self,
@@ -731,7 +783,6 @@ def apply(
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")

-        #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
         # We can reuse the memory between these because by the time we need
         # cache3, we're done with cache1
         intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
@@ -761,36 +812,17 @@ def apply(
         self.activation(activation, intermediate_cache2.view(-1, N // 2),
                         intermediate_cache1.view(-1, N))

-        #qintermediate_cache2 = intermediate_cache2
-
-        # TODO (varun) : support w8a8
-        #assert not self.use_fp8_w8a8
-        if self.use_fp8_w8a8:
-            per_act_token = False
-            # TODO: reuse?
-            qintermediate_cache2 = torch.empty_like(intermediate_cache2,
-                                                    dtype=torch.float8_e4m3fn)
-            block_n = self.block_shape[0]
-            n_tiles = ((N // 2) + block_n - 1) // block_n
-            scale_shape = (E, num_tokens, n_tiles)
-            a2q_scale = torch.empty(scale_shape,
-                                    dtype=torch.float32,
-                                    device=hidden_states.device)
-            for e in range(E):
-                num_tokens = expert_num_tokens[e]
-                if num_tokens > 0:
-                    #qintermediate_cache2[e], tmp_scale = _fp8_quantize(
-                    #    intermediate_cache2[e],
-                    #    a2_scale[e] if a2_scale is not None else None,
-                    #    per_act_token, self.block_shape)
-                    qintermediate_cache2[
-                        e, :
-                        num_tokens, :], tmp_scale = per_token_group_quant_fp8(
-                            intermediate_cache2[e, :num_tokens], block_n)
-                    a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale
-        else:
-            qintermediate_cache2 = intermediate_cache2
-            a2q_scale = a2_scale
+        qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
+            intermediate_cache2,
+            a2_scale,
+            num_tokens,
+            E,
+            N,
+            expert_num_tokens,
+            self.qtype,
+            self.per_act_token,
+            self.block_shape
+        )

         invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
                                          B=w2,
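
Note: one property the consolidated call relies on is that qtype=None makes the helper a passthrough (see its else branch above), so the unquantized path no longer needs its own else branch here. A tiny check with made-up shapes, assuming the helper is importable from this module:

    import torch

    x = torch.randn(2, 4, 8)
    out, scale = batched_moe_kernel_quantize_input(
        x, None, 4, 2, 16, torch.tensor([4, 4]),
        qtype=None, per_channel_quant=False, block_shape=None)
    assert out is x and scale is None
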