from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
-                                               RowParallelLinear)
+                                               RowParallelLinear,
+                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
                                              MultiStreamStepMetadata,
                                              make_multistream_metadata_ds)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO


class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):

+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
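+        # True only when the layer is quantized and the wrapped Ascend scheme
+        # is W8A8 dynamic quantization (see the isinstance checks below).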
+        self.is_dynamic_quant = not isinstance(
+            self.gate_up_proj.quant_method,
+            UnquantizedLinearMethod) and isinstance(
+                self.gate_up_proj.quant_method.quant_method,
+                AscendW8A8DynamicLinearMethod)
+
    def _forward_ms_mlp(self, x):
        current_ms_metadata = get_multistream_comm_context()
        assert current_ms_metadata is not None
        gate_up, _ = self.gate_up_proj(x)
-        x, dynamic_scale = self.act_fn(gate_up)
-        x = torch_npu.npu_quant_matmul(
-            x,
-            self.down_proj.weight,
-            self.down_proj.weight_scale,
-            pertoken_scale=dynamic_scale,
-            output_dtype=torch.bfloat16,
-        )
-        if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                x = tensor_model_parallel_all_reduce(x)
-                current_ms_metadata.after_comm_event.record()
+        if self.is_dynamic_quant:
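+            # Dynamic-quant path: act_fn also returns per-token scales, which
+            # feed the fused npu_quant_matmul down projection; the TP
+            # all-reduce below is overlapped on the multistream comm stream.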
+            x, dynamic_scale = self.act_fn(gate_up)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    x = tensor_model_parallel_all_reduce(x)
+                    current_ms_metadata.after_comm_event.record()
+        else:
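+            # Fallback for unquantized layers: plain activation and down_proj.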
+            x = self.act_fn(gate_up)
+            x, _ = self.down_proj(x)
        return x

@@ -796,6 +822,7 @@ def forward(
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
+        graph_enable: Optional[bool] = True
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
@@ -809,8 +836,9 @@ def forward(
            residual = intermediate_tensors["residual"]

        num_normal_layers = (self.first_k_dense_replace
-                             if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms()
-                             else self.end_layer - self.start_layer)
+                             if VLLM_ASCEND_ENABLE_DBO and not graph_enable
+                             and self.can_run_ms() else self.end_layer -
+                             self.start_layer)

        moe_start_layer = self.start_layer + num_normal_layers
        for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
@@ -847,15 +875,13 @@ def can_run_ms(self):
            return False
        return True

-    def _forward_ms_layers(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: torch.Tensor,
-        moe_start_layer: int,
-        kv_caches: Optional[List[torch.Tensor]] = None,
-        is_prefill: bool = False,
-    ):
+    def _forward_ms_layers(self,
+                           positions: torch.Tensor,
+                           hidden_states: torch.Tensor,
+                           residual: torch.Tensor,
+                           moe_start_layer: int,
+                           kv_caches: Optional[List[torch.Tensor]] = None,
+                           is_prefill: bool = False):

        if moe_start_layer == self.end_layer:
            return hidden_states, residual
@@ -917,8 +943,9 @@ def forward(
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
+        graph_enable: Optional[bool] = True
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
-                                   inputs_embeds)
+                                   inputs_embeds, graph_enable)
        return hidden_states