 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
+# Temporarily disable yapf since it conflicts with isort.
+# yapf: disable
 from vllm.distributed import (get_dp_group, get_pp_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group)
+                              get_tp_group, split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+# yapf: enable
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -132,6 +139,80 @@ def weight_loader(self, param: torch.nn.Parameter,
         shard.copy_(loaded_weight)


+class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
+
+    def forward(
+        self,
+        input_,
+        is_prefill=True
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        if self.reduce_results and self.tp_size > 1:
+            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+                output = tensor_model_parallel_reduce_scatter(output_parallel,
+                                                              dim=0)
+            else:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
+
+    def forward(
+        self,
+        input_,
+        is_prefill=True
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        if self.reduce_results and self.tp_size > 1:
+            output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 class CustomDeepseekV2MLP(nn.Module):

     def __init__(
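Note: the only behavioral difference between the two new o_proj classes is the collective used after the local GEMM on the decode path. A minimal, self-contained sketch of that decision logic (illustrative only; the local `all_reduce`/`reduce_scatter` helpers below merely stand in for the vLLM tensor-parallel collectives imported above):

import torch

def combine_row_parallel_output(output_parallel: torch.Tensor,
                                tp_size: int,
                                is_prefill: bool,
                                replace_allreduce: bool) -> torch.Tensor:
    # Single-process stand-ins for the TP collectives, for shape illustration.
    def all_reduce(t: torch.Tensor) -> torch.Tensor:
        return t  # summed partial outputs, full shape kept on every rank

    def reduce_scatter(t: torch.Tensor) -> torch.Tensor:
        return t.chunk(tp_size, dim=0)[0]  # each rank keeps 1/tp_size of dim 0

    if replace_allreduce and not is_prefill \
            and output_parallel.shape[0] % tp_size == 0:
        # "ReplaceAllreduce" variant, decode batch divisible by tp_size:
        # keep only a token slice; a later all-gather restores the full batch.
        return reduce_scatter(output_parallel)
    # Plain variant (and all prefill / non-divisible cases): full all-reduce.
    return all_reduce(output_parallel)

out = combine_row_parallel_output(torch.randn(8, 16), tp_size=4,
                                  is_prefill=False, replace_allreduce=True)
print(out.shape)  # torch.Size([2, 16])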
@@ -291,10 +372,10 @@ def __init__(

         self.params_dtype = torch.get_default_dtype()

-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+    def forward(self,
+                hidden_states: torch.Tensor,
+                attn_metadata: Optional[AttentionMetadata] = None,
+                replace_allreduce: bool = False) -> torch.Tensor:
         forward_context = get_forward_context()
         if attn_metadata is None:
             attn_metadata = forward_context.attn_metadata
@@ -323,7 +404,7 @@ def forward(
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
             gate=self.gate if self.enable_multistream_moe else None,
-        )
+            replace_allreduce=replace_allreduce)

         hidden_states = (
             experts_hidden_states[0] * self.routed_scaling_factor +
@@ -370,6 +451,14 @@ def __init__(
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                              self.q_lora_rank,
@@ -406,11 +495,23 @@ def __init__(
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
-                                        self.hidden_size,
-                                        bias=False,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.o_proj")
+        if (config.n_routed_experts is not None
+                and self.debug_layer_idx >= config.first_k_dense_replace
+                and self.debug_layer_idx % config.moe_layer_freq == 0 and
+                ascend_config.torchair_graph_config.enable_multistream_moe):
+            self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj")
+        else:
+            self.o_proj = CustomDeepseekV2RowParallelLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
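The branch above mirrors how DeepSeek configs mark MoE layers, so only layers that actually run routed experts get the reduce-scatter o_proj. A rough standalone restatement of the predicate (config field names come from the diff; the default values are illustrative, not taken from any particular checkpoint):

def uses_reduce_scatter_o_proj(layer_idx: int,
                               n_routed_experts,
                               first_k_dense_replace: int = 3,
                               moe_layer_freq: int = 1,
                               enable_multistream_moe: bool = True) -> bool:
    # Dense warm-up layers (< first_k_dense_replace) and non-MoE layers keep
    # the ordinary all-reduce o_proj; routed-expert layers may reduce-scatter.
    return (n_routed_experts is not None
            and layer_idx >= first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
            and enable_multistream_moe)

print(uses_reduce_scatter_o_proj(0, 64))   # False: dense layer
print(uses_reduce_scatter_o_proj(10, 64))  # True: routed-expert layer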
@@ -456,14 +557,6 @@ def __init__(
             o_proj=self.o_proj,
         )

-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -530,6 +623,10 @@ def __init__(
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
         self.layer_idx = layer_idx
+        self.layers = config.num_hidden_layers
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tp_group().rank_in_group
+        ascend_config = get_ascend_config()
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
             attn_cls = CustomDeepseekV2MLAAttention
@@ -561,6 +658,8 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
+                and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
         else:
             self.mlp = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
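In plain terms, the flag set above is only true when every precondition for the communication rewiring holds; a compact sketch (the config and env fields are the ones referenced in the diff):

def mla_moe_communication_enabled(enable_multistream_moe: bool,
                                  use_mla: bool,
                                  vllm_use_v1: bool,
                                  tp_size: int) -> bool:
    # The reduce-scatter/all-gather scheme only applies with MLA attention,
    # the V1 engine and more than one tensor-parallel rank.
    return enable_multistream_moe and use_mla and vllm_use_v1 and tp_size > 1

print(mla_moe_communication_enabled(True, True, True, 4))  # True
print(mla_moe_communication_enabled(True, True, True, 1))  # False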
@@ -569,11 +668,13 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            self.mla_moe_communication = False
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.first_k_dense_replace = config.first_k_dense_replace

     def forward(
         self,
@@ -582,8 +683,13 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
+        if attn_metadata is not None and attn_metadata.num_decodes > 0:
+            mla_moe_communication = self.mla_moe_communication and replace_allreduce
+        else:
+            mla_moe_communication = False
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
@@ -595,6 +701,9 @@ def forward(
             # to save npu memory because they're no longer used.
             dispose_tensor(previous_hidden_states)
             dispose_tensor(previous_residual)
+        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                             dim=0)

         hidden_states = self.self_attn(
             positions=positions,
@@ -603,6 +712,13 @@ def forward(
             attn_metadata=attn_metadata,
         )

+        if mla_moe_communication and residual.shape[0] != hidden_states.shape[
+                0]:
+            chunk_hidden_states = torch.tensor_split(residual,
+                                                     self.tp_size,
+                                                     dim=0)
+            residual = chunk_hidden_states[self.tp_rank]
+
         if hidden_states.dtype == torch.float16:
             # Fix FP16 overflow
             # We scale both hidden_states and residual before
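To see why residual needs slicing here: on the decode path, the o_proj output above is reduce-scattered along the token dimension while residual still carries the full batch. A shape-only walkthrough with simulated collectives (tp_size=4, 8 tokens, hidden size 16; purely illustrative, not the actual distributed calls):

import torch

tp_size, tp_rank, num_tokens, hidden = 4, 0, 8, 16

hidden_states = torch.randn(num_tokens // tp_size, hidden)  # after the previous layer's reduce-scatter
residual = torch.randn(num_tokens, hidden)                  # still the full batch

# 1) All-gather before attention restores the full token batch.
hidden_states = hidden_states.repeat(tp_size, 1)  # stand-in for tensor_model_parallel_all_gather(dim=0)
assert hidden_states.shape[0] == num_tokens

# 2) The o_proj reduce-scatter leaves this rank a token slice again.
hidden_states = hidden_states.chunk(tp_size, dim=0)[tp_rank]

# 3) residual must be sliced the same way before the residual add;
#    the final layer all-gathers both back to the full batch.
residual = torch.tensor_split(residual, tp_size, dim=0)[tp_rank]
assert residual.shape == hidden_states.shape  # both [num_tokens // tp_size, hidden]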
@@ -618,7 +734,9 @@ def forward(
                 hidden_states, residual)

         if isinstance(self.mlp, CustomDeepseekV2MoE):
-            hidden_states = self.mlp(hidden_states, attn_metadata)
+            hidden_states = self.mlp(hidden_states,
+                                     attn_metadata,
+                                     replace_allreduce=mla_moe_communication)
         else:
             hidden_states = self.mlp(hidden_states)

@@ -631,6 +749,10 @@ def forward(
             # The scaling of DeepseekV2MOE output would be done in the forward
             # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
+        if mla_moe_communication and self.layer_idx == self.layers - 1:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                             dim=0)
+            residual = tensor_model_parallel_all_gather(residual, dim=0)

         return hidden_states, residual

@@ -649,6 +771,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.tp_size = get_tensor_model_parallel_world_size()

         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -701,13 +824,18 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, residual,
+                positions,
+                hidden_states,
+                residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata,
+                replace_allreduce=replace_allreduce)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
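At the model level the new flag is just a divisibility check on the token count, so the reduce-scatter/all-gather pair never changes shapes; a small worked example (the helper name is made up for illustration):

def should_replace_allreduce(num_tokens: int, tp_size: int) -> bool:
    # 8 tokens on TP=4 split cleanly into 2 per rank; 10 tokens would not,
    # so such batches fall back to the plain all-reduce path in every layer.
    return num_tokens % tp_size == 0

print(should_replace_allreduce(8, 4))   # True  -> layers may use reduce-scatter
print(should_replace_allreduce(10, 4))  # False -> keep all-reduce everywhere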