@@ -633,9 +633,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for "
-                "`CompressedTensorsW8A8Fp8MoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -649,6 +650,11 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )

         if self.rocm_aiter_moe_enabled:
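With this change, EPLB support stops raising `NotImplementedError`: the bookkeeping tensors are validated up front and then forwarded into `FusedMoE.select_experts`. Below is a minimal, hedged sketch of the same guard-then-forward pattern as a standalone toy; the function name and simplified signature are illustrative, not vLLM's actual API.

# Hedged sketch of the guard-then-forward pattern in the diff above.
# Standalone toy, not vLLM's code: names and signature are illustrative.
from typing import Optional
import torch

def route_with_eplb(
    router_logits: torch.Tensor,
    top_k: int,
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if enable_eplb:
        # Fail fast if EPLB is enabled but a bookkeeping tensor is missing,
        # mirroring the asserts added in the diff.
        assert expert_load_view is not None
        assert logical_to_physical_map is not None
        assert logical_replica_count is not None
    weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
    return weights, ids

# EPLB disabled: the optional tensors may stay None.
weights, ids = route_with_eplb(torch.randn(2, 8), top_k=2)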
@@ -913,9 +919,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for "
-                "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -927,7 +934,12 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)

         a1_scale = layer.w13_input_scale
         a2_scale = layer.w2_input_scale
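For context, here is a toy illustration of what the forwarded tensors are conventionally understood to mean in EPLB-style routing. This is an assumption about their semantics, not code from this commit: each logical expert may be served by several physical replicas, and a routed token is dispatched to one of them.

# Toy illustration only; assumes the conventional EPLB layout (an assumption,
# not taken from this commit): row i of logical_to_physical_map lists the
# physical ids of logical expert i's replicas, padded with -1.
import torch

logical_to_physical_map = torch.tensor([[0, 1], [2, -1], [3, 4], [5, -1]])
logical_replica_count = torch.tensor([2, 1, 2, 1])  # valid replicas per row

topk_ids = torch.tensor([[0, 2], [1, 3]])  # logical expert ids chosen per token
# Pick one replica per routed token (random choice here for simplicity).
replica = torch.randint(0, 1 << 30, topk_ids.shape) % logical_replica_count[topk_ids]
physical_ids = logical_to_physical_map[topk_ids, replica]
print(physical_ids)  # physical expert each routed token is dispatched to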