from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
@@ -72,8 +71,7 @@
                                              make_multistream_metadata_ds)
from vllm_ascend.quantization.w8a8_dynamic import (
    AscendW8A8DynamicLinearMethod, apply_mlp)
-from vllm_ascend.ops.fused_moe import AscendFusedMoE, apply_mlp, select_experts
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.ops.fused_moe import apply_mlp, select_experts
from vllm_ascend.utils import dispose_tensor

VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
@@ -94,7 +92,8 @@ def __init__(
                         intermediate_size=intermediate_size,
                         hidden_act=hidden_act,
                         quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         reduce_results=reduce_results)
        self.is_dynamic_quant = not isinstance(
            self.gate_up_proj.quant_method,
            UnquantizedLinearMethod) and isinstance(
@@ -152,19 +151,6 @@ def __init__(
                prefix=f"{prefix}.shared_experts",
            )
        CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
-        self.params_dtype = torch.get_default_dtype()
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.config = config

    def forward(
@@ -196,9 +182,13 @@ def forward(
            enable_force_load_balance=enable_force_load_balance,
            shared_experts=self.shared_experts)

+        shared_experts_hidden = experts_hidden_states[1]
+        if not (self.shared_experts.down_proj.reduce_results and self.shared_experts.down_proj.tp_size > 1):
+            shared_experts_hidden = tensor_model_parallel_all_reduce(shared_experts_hidden)
+
        hidden_states = (
            experts_hidden_states[0] * self.routed_scaling_factor +
-            experts_hidden_states[1])
+            shared_experts_hidden)

        return hidden_states
@@ -225,18 +215,10 @@ def _forward_op_gating(
    ) -> torch.Tensor:
        if attn_metadata is None:
            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        # TODO: need a better flag to indicate whether in profile run or not.
-        if attn_metadata is None:
-            # for profile run
-            self.is_prefill = True
-            self.enable_force_load_balance = True
-        else:
-            is_prefill = attn_metadata.num_prefills > 0
-            self.enable_force_load_balance = False
-            if hasattr(attn_metadata, 'with_prefill_across_dp'):
-                self.is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        enable_force_load_balance = get_forward_context().in_profile_run

        num_tokens, hidden_dim = hidden_states.shape
@@ -291,17 +273,11 @@ def _forward_op_gating(
        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
-        if self.enable_force_load_balance:
+        if enable_force_load_balance:
            topk_ids = torch.randint_like(topk_ids, 0, self.config.n_routed_experts)

        return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes

-    def _forward_dispatch_comm(
-            self, hidden_states, topk_weights, topk_ids, microbatch_id
-    ):
-        token_dispatcher = self.experts.token_dispatchers[microbatch_id]
-        _, hidden_states, tokens_per_expert = token_dispatcher.token_permutation(hidden_states, topk_weights, topk_ids)
-        return hidden_states, tokens_per_expert

    def _forward_op_shared_experts(
        self, hidden_states
@@ -315,7 +291,7 @@ def _forward_op_grouped_mlp(
        self, dispatched_input, tokens_per_expert
    ):
        return apply_mlp(
-            [dispatched_input],
+            dispatched_input,
            self.experts.w13_weight,
            self.experts.w2_weight,
            tokens_per_expert
@@ -325,8 +301,9 @@ def _forward_combine_comm(
        self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes
    ):
        token_dispatcher = self.experts.token_dispatchers[microbatch_id]
-        token_dispatcher.combine_alltoall()
-        final_hidden_states = token_dispatcher.unpermute2() * self.routed_scaling_factor
+        final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states)
+        if hasattr(self, 'routed_scaling_factor'):
+            final_hidden_states = final_hidden_states * self.routed_scaling_factor

        if self.tp_size > 1:
            final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group,
@@ -794,17 +771,12 @@ def _forward_ms_layer_alltoallv_finegrained(
        chunked_hidden_states_sizes = [None] * num_micro_batchs
        token_dispatchers = self.mlp.experts.token_dispatchers

-        def print_with_sync(*args, **kwargs):
-            torch.npu.synchronize()
-            print(*args, **kwargs)
-
        def discard_tensor(tensor):
            if isinstance(tensor, torch.Tensor):
                tensor = [tensor]
            for t in tensor:
                t.untyped_storage().resize_(0)

-        # print_with_sync('begin layer...', torch.distributed.get_rank())

        # block 1 : attention
        # block 2 : Router Gating
@@ -814,12 +786,11 @@ def discard_tensor(tensor):
        # can be overlapped with the attn communication of microbatch 1
        for i in range(num_micro_batchs):
            # wait last layer moe finishing communication
-            ms_metadata.try_wait_event(layer_index - 1, i,
-                                       MSEventKey.MOE_AFTER_COMM)

            forward_context = get_forward_context()
            layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
            )
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
            forward_context.attn_metadata = attn_metadata[i]

            # input layernorm
@@ -856,9 +827,10 @@ def discard_tensor(tensor):
            with torch.npu.stream(dispatch_context.comm_stream):
                dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event)
                token_dispatchers[i].dispatch_alltoall()
+                dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
                dispatch_context.after_comm_event.record()

-            if self.mlp.n_shared_experts:
+            if self.mlp.n_shared_experts and self.tp_size > 1:
                token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce(
                    token_dispatchers[i].cached_shared_expert_output
                )
@@ -872,20 +844,16 @@ def discard_tensor(tensor):
            ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM)
            discard_tensor(hidden_states[i])

-            dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
            router_expert_output[i] = self.mlp._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i])
            discard_tensor(dispatched_input[i])
-            token_dispatchers[i].unpermute1(router_expert_output[i])
-            if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1:
-                discard_tensor(router_expert_output[i])

            # Launch Combine Comm in a New Stream.
            combine_context = MultiStreamStepMetadata(
                comm_stream=ms_metadata.communicate_stream,
                before_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_BEFORE_COMM],
+                    MSEventKey.FFN_COM_FINISH],
                after_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_AFTER_COMM],
+                    MSEventKey.FFN_AR_FINISH],
            )
            combine_context.before_comm_event.record()
            ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
@@ -1032,7 +1000,6 @@ def forward(
                             if VLLM_ASCEND_ENABLE_DBO and not graph_enable
                             and self.can_run_ms() else self.end_layer -
                             self.start_layer)
-
        moe_start_layer = self.start_layer + num_normal_layers
        for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
            layer = self.layers[i]
@@ -1068,16 +1035,6 @@ def can_run_ms(self):
            return False
        return True

-    def all_can_run_ms(self):
-        can_run_ms_local = self.can_run_ms()
-        ep_group = get_ep_group().cpu_group
-        flag = torch.ones(1, dtype=torch.int) if can_run_ms_local else torch.zeros(1, dtype=torch.int)
-        torch.distributed.all_reduce(flag, group=ep_group)
-        if flag.item() == torch.distributed.get_world_size(ep_group):
-            return True
-        else:
-            return False
-
    def _forward_ms_layers(self,
                           positions: torch.Tensor,
                           hidden_states: torch.Tensor,
@@ -1098,9 +1055,7 @@ def _forward_ms_layers(self,
            layer = self.layers[i]
            ms_layer_forward_func = layer._forward_ms_layer
            if fused_moe_state == FusedMoEState.All2AllSeq:
-                # ms_layer_forward_func = layer._forward_ms_layer_alltoallv
                ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained
-                # print("get_called......")
            hidden_states, residual = ms_layer_forward_func(
                positions=positions,
                hidden_states=hidden_states,