                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-
-from vllm_ascend.utils import vllm_version_is
-
-if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-    from vllm.model_executor.layers.fused_moe.layer import (
-        FusedMoEParallelConfig, MoEConfig)
-else:
-    MoEConfig = None
-
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
+    determine_expert_map)
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
@@ -587,10 +579,8 @@ def select_experts(
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: MoEConfig = None):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            super().__init__()
-        else:
-            super().__init__(moe=moe)
+
+        super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()
 
         ep_group = get_ep_group()
@@ -731,24 +721,17 @@ def __init__(
             params_dtype = torch.get_default_dtype()
 
         vllm_config = get_current_vllm_config()
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            self.ep_size = get_ep_group().world_size
-            self.tp_size = get_etp_group().world_size
-            self.dp_size = (dp_size if dp_size is not None else
-                            get_dp_group().world_size)
-            self.dp_rank = (0 if self.dp_size == 1 else
-                            get_dp_group().rank_in_group)
-        else:
-            self.moe_parallel_config: FusedMoEParallelConfig = (
-                FusedMoEParallelConfig.make(
-                    tp_size_=(tp_size if tp_size is not None else
-                              get_tensor_model_parallel_world_size()),
-                    dp_size_=(dp_size if dp_size is not None else
-                              get_dp_group().world_size),
-                    vllm_parallel_config=vllm_config.parallel_config))
 
-            self.moe_parallel_config.ep_size = get_ep_group().world_size
-            self.moe_parallel_config.tp_size = get_etp_group().world_size
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))
+
+        self.moe_parallel_config.ep_size = get_ep_group().world_size
+        self.moe_parallel_config.tp_size = get_etp_group().world_size
 
         self.top_k = top_k
         self.num_experts = num_experts
@@ -773,54 +756,39 @@ def __init__(
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = get_etp_group().rank_in_group
-                self.ep_rank = get_ep_group().rank_in_group
-            else:
-                self.moe_parallel_config.tp_rank = get_etp_group(
-                ).rank_in_group
-                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
+            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = self.tp_size * self.dp_rank
-                self.ep_rank = 0
-                self.tp_size = self.tp_size * self.dp_size
-                self.ep_size = 1
-            else:
-                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-                self.moe_parallel_config.ep_rank = 0
-                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-                self.moe_parallel_config.ep_size = 1
+
+            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+            self.moe_parallel_config.ep_rank = 0
+            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+            self.moe_parallel_config.ep_size = 1
 
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            if quant_config is None:
-                self.quant_method: Optional[QuantizeMethodBase] = (
-                    AscendUnquantizedFusedMoEMethod())
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
-        else:
-            moe = MoEConfig(
-                num_experts=self.global_num_experts,
-                experts_per_token=top_k,
-                hidden_dim=hidden_size,
-                num_local_experts=self.local_num_experts,
-                moe_parallel_config=self.moe_parallel_config,
-                # TODO (bnell): this needs to be fixed for quantized types.
-                in_dtype=params_dtype,
-            )
 
-            if quant_config is None:
-                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
+        moe = MoEConfig(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
+        )
+
+        if quant_config is None:
+            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
 
         assert self.quant_method is not None
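Note: as a minimal sketch of the unified path this diff leaves behind, the helper below shows the call order (build the upstream `FusedMoEParallelConfig`, override its EP/TP fields with the Ascend groups, then build `MoEConfig`). The `build_ascend_moe_config` function is illustrative only and not part of this commit; it assumes vLLM ≥ 0.9 APIs and already-initialized TP/DP/EP/ETP process groups.

```python
# Illustrative sketch of the post-change flow in AscendFusedMoE.__init__;
# not code from this commit. Requires an active vllm config context and
# initialized distributed process groups.
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoEParallelConfig, MoEConfig, determine_expert_map)

from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group


def build_ascend_moe_config(num_experts: int, top_k: int,
                            hidden_size: int) -> MoEConfig:
    """Hypothetical helper mirroring the unified (no 0.8.5 branch) code path."""
    vllm_config = get_current_vllm_config()

    # Always build the upstream parallel config first ...
    parallel_config = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config)

    # ... then override the EP/TP views with the Ascend-specific groups,
    # as the hunks above do on self.moe_parallel_config.
    parallel_config.ep_size = get_ep_group().world_size
    parallel_config.tp_size = get_etp_group().world_size

    if parallel_config.ep_size > 1:
        # Expert parallelism: shard the experts across the EP group.
        local_num_experts, _expert_map = determine_expert_map(
            parallel_config.ep_size,
            get_ep_group().rank_in_group, num_experts)
    else:
        local_num_experts = num_experts

    return MoEConfig(
        num_experts=num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=local_num_experts,
        moe_parallel_config=parallel_config,
        in_dtype=torch.get_default_dtype(),
    )
```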