@@ -737,10 +737,8 @@ def __init__(
737
737
"For FP8 Fused MoE layer, we require either per tensor or "
738
738
"channelwise, dynamic per token quantization." )
739
739
740
- from vllm .model_executor .layers .fused_moe .cutlass_moe import (
741
- cutlass_moe_fp8 )
742
740
self .topk_indices_dtype = None
743
- self .fused_experts = cutlass_moe_fp8 # type: ignore
741
+ self .fused_experts = None # type: ignore
744
742
self .disable_expert_map = False
745
743
746
744
def create_weights (self , layer : torch .nn .Module , num_experts : int ,
@@ -936,21 +934,40 @@ def apply(
936
934
per_act_token = a1_scale .numel () != 1 if a1_scale is not None else (
937
935
a2_scale .numel () != 1 if a2_scale is not None else False )
938
936
939
- return self .fused_experts (
940
- x ,
941
- layer .w13_weight ,
942
- layer .w2_weight ,
943
- topk_weights ,
944
- topk_ids ,
945
- per_act_token = per_act_token ,
946
- activation = activation ,
947
- global_num_experts = global_num_experts ,
948
- expert_map = None if self .disable_expert_map else expert_map ,
949
- w1_scale = layer .w13_weight_scale ,
950
- w2_scale = layer .w2_weight_scale ,
951
- a1_scale = a1_scale ,
952
- a2_scale = a2_scale ,
953
- )
937
+ if self .fused_experts is None :
938
+ # If no modular kernel is provided, use cutlass_moe_fp8
939
+ from vllm .model_executor .layers .fused_moe .cutlass_moe import (
940
+ cutlass_moe_fp8 )
941
+ return cutlass_moe_fp8 (
942
+ x ,
943
+ layer .w13_weight ,
944
+ layer .w2_weight ,
945
+ topk_weights ,
946
+ topk_ids ,
947
+ per_act_token = per_act_token ,
948
+ activation = activation ,
949
+ global_num_experts = global_num_experts ,
950
+ expert_map = None if self .disable_expert_map else expert_map ,
951
+ w1_scale = layer .w13_weight_scale ,
952
+ w2_scale = layer .w2_weight_scale ,
953
+ a1_scale = a1_scale ,
954
+ a2_scale = a2_scale ,
955
+ )
956
+ else :
957
+ return self .fused_experts (
958
+ x ,
959
+ layer .w13_weight ,
960
+ layer .w2_weight ,
961
+ topk_weights ,
962
+ topk_ids ,
963
+ activation = activation ,
964
+ global_num_experts = global_num_experts ,
965
+ expert_map = None if self .disable_expert_map else expert_map ,
966
+ w1_scale = layer .w13_weight_scale ,
967
+ w2_scale = layer .w2_weight_scale ,
968
+ a1_scale = layer .w13_input_scale ,
969
+ a2_scale = layer .w2_input_scale ,
970
+ )
954
971
955
972
956
973
class CompressedTensorsW8A8Int8MoEMethod (CompressedTensorsMoEMethod ):
0 commit comments