
Commit 1bd5be9

adjusting the communication method in graph mode (#1194)

sharonyunyun authored and sdmyzlp committed

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>

1 parent 066ea10 · commit 1bd5be9

3 files changed: 161 additions, 28 deletions
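
The change replaces the tensor-parallel all-reduce after the MLA output projection with a reduce-scatter during decode in torchair graph mode, and defers the matching all-gather until the full token dimension is needed again. The trick relies on all_reduce(x) being equivalent, along the token dimension, to reduce_scatter(x) followed by all_gather, provided the token count divides evenly by the TP world size; the MoE layers in between can then operate on each rank's token slice. Below is a minimal single-process sketch of that identity, simulating the collectives with plain tensor ops (no process group; all names here are illustrative, not vLLM APIs).

    import torch

    def simulate_all_reduce(per_rank_outputs):
        # Every rank ends up with the full summed tensor.
        total = torch.stack(per_rank_outputs).sum(dim=0)
        return [total.clone() for _ in per_rank_outputs]

    def simulate_reduce_scatter_then_all_gather(per_rank_outputs, tp_size):
        # reduce_scatter(dim=0): rank r keeps the r-th token slice of the sum.
        total = torch.stack(per_rank_outputs).sum(dim=0)
        slices = torch.tensor_split(total, tp_size, dim=0)
        scattered = [slices[r].clone() for r in range(tp_size)]
        # ... per-rank compute (e.g. the MoE block) could run here on the slice ...
        # all_gather(dim=0): reassemble the full token dimension later.
        gathered = torch.cat(scattered, dim=0)
        return [gathered.clone() for _ in range(tp_size)]

    tp_size, num_tokens, hidden = 4, 8, 16   # num_tokens divisible by tp_size
    per_rank = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]
    a = simulate_all_reduce(per_rank)[0]
    b = simulate_reduce_scatter_then_all_gather(per_rank, tp_size)[0]
    assert torch.allclose(a, b)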

vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 3 deletions

@@ -9,6 +9,7 @@
                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -584,6 +585,7 @@ def __init__(
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.tp_size = get_tensor_model_parallel_world_size()

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -602,7 +604,7 @@ def _v_up_proj_and_o_proj(self, x):
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
+        return self.o_proj(x, is_prefill=False)[0]

     # Return `ql_nope`, `q_pe`
     def _q_proj_and_k_up_proj(self, x):
@@ -867,12 +869,12 @@ def _forward_prefill(

         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self.o_proj(attn_output)[0]
+            return self.o_proj(attn_output, is_prefill=True)[0]
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output)[0]
+                return self.o_proj(attn_output, is_prefill=True)[0]

     def exec_kv(
         self,
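
The only functional change in this file is that o_proj now receives an is_prefill flag: the decode path (_v_up_proj_and_o_proj) passes is_prefill=False, while _forward_prefill passes is_prefill=True, so only decode-phase outputs can take the reduce-scatter path added in deepseek_v2.py below. A toy mock of that dispatch, assuming nothing about the real layer beyond what the diff shows (MockOProj is invented for illustration, not a vLLM class):

    import torch
    import torch.nn as nn

    class MockOProj(nn.Module):
        """Toy stand-in for the row-parallel output projection."""

        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, out_dim, bias=False)

        def forward(self, x, is_prefill=True):
            out = self.linear(x)
            # In the real layer this flag picks all-reduce (prefill) vs
            # reduce-scatter (decode with a divisible token count).
            comm = "all_reduce" if is_prefill else "reduce_scatter"
            return out, comm

    o_proj = MockOProj(128, 64)
    decode_out, decode_comm = o_proj(torch.randn(4, 128), is_prefill=False)
    prefill_out, prefill_comm = o_proj(torch.randn(16, 128), is_prefill=True)
    assert decode_comm == "reduce_scatter" and prefill_comm == "all_reduce"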

vllm_ascend/models/deepseek_v2.py

Lines changed: 150 additions & 22 deletions

@@ -35,9 +35,16 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
+# Temporarily disable yapf since it conflicts with isort.
+# yapf: disable
 from vllm.distributed import (get_dp_group, get_pp_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group)
+                              get_tp_group, split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+# yapf: enable
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -132,6 +139,80 @@ def weight_loader(self, param: torch.nn.Parameter,
             shard.copy_(loaded_weight)


+class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
+
+    def forward(
+        self,
+        input_,
+        is_prefill=True
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        if self.reduce_results and self.tp_size > 1:
+            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+                output = tensor_model_parallel_reduce_scatter(output_parallel,
+                                                              dim=0)
+            else:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
+
+    def forward(
+        self,
+        input_,
+        is_prefill=True
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        if self.reduce_results and self.tp_size > 1:
+            output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 class CustomDeepseekV2MLP(nn.Module):

     def __init__(
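
The two subclasses above differ only in how they finish the row-parallel matmul: CustomDeepseekV2RowParallelLinearReplaceAllreduce issues a reduce-scatter along the token dimension when called with is_prefill=False and the token count divides evenly by tp_size, and falls back to the ordinary all-reduce otherwise, while CustomDeepseekV2RowParallelLinear always all-reduces. A distilled sketch of just that decision; the helper name and string return values are mine, not part of the diff:

    import torch

    def finalize_row_parallel_output(output_parallel: torch.Tensor,
                                     *,
                                     is_prefill: bool,
                                     tp_size: int,
                                     reduce_results: bool = True) -> str:
        """Return which collective the layer would issue for this call."""
        if not (reduce_results and tp_size > 1):
            return "none"                   # single rank or reduction disabled
        if not is_prefill and output_parallel.shape[0] % tp_size == 0:
            return "reduce_scatter(dim=0)"  # decode: keep a token slice per rank
        return "all_reduce"                 # prefill, or non-divisible batch

    x = torch.randn(8, 64)
    assert finalize_row_parallel_output(x, is_prefill=False, tp_size=4) == "reduce_scatter(dim=0)"
    assert finalize_row_parallel_output(torch.randn(7, 64), is_prefill=False, tp_size=4) == "all_reduce"
    assert finalize_row_parallel_output(x, is_prefill=True, tp_size=4) == "all_reduce"
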
@@ -291,10 +372,10 @@ def __init__(

         self.params_dtype = torch.get_default_dtype()

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+    def forward(self,
+                hidden_states: torch.Tensor,
+                attn_metadata: Optional[AttentionMetadata] = None,
+                replace_allreduce: bool = False) -> torch.Tensor:
         forward_context = get_forward_context()
         if attn_metadata is None:
             attn_metadata = forward_context.attn_metadata
@@ -323,7 +404,7 @@ def forward(
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
             gate=self.gate if self.enable_multistream_moe else None,
-        )
+            replace_allreduce=replace_allreduce)

         hidden_states = (
             experts_hidden_states[0] * self.routed_scaling_factor +
@@ -370,6 +451,14 @@ def __init__(
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                              self.q_lora_rank,
@@ -406,11 +495,23 @@ def __init__(
                                              bias=False,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
-                                        self.hidden_size,
-                                        bias=False,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.o_proj")
+        if (config.n_routed_experts is not None
+                and self.debug_layer_idx >= config.first_k_dense_replace
+                and self.debug_layer_idx % config.moe_layer_freq == 0 and
+                ascend_config.torchair_graph_config.enable_multistream_moe):
+            self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj")
+        else:
+            self.o_proj = CustomDeepseekV2RowParallelLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
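
Only layers that actually run routed experts get the reduce-scatter variant of o_proj: the layer index must be at least first_k_dense_replace, a multiple of moe_layer_freq, and multistream MoE must be enabled; dense layers keep the plain all-reduce projection. A small restatement of that predicate with an invented helper and config stub (the field values are placeholders, not asserted DeepSeek defaults):

    from dataclasses import dataclass

    @dataclass
    class _CfgStub:
        # Illustrative stand-in for the relevant DeepseekV2 config fields.
        n_routed_experts: int = 160
        first_k_dense_replace: int = 3
        moe_layer_freq: int = 1

    def uses_replace_allreduce_o_proj(layer_idx: int,
                                      cfg: _CfgStub,
                                      enable_multistream_moe: bool) -> bool:
        """Mirror of the o_proj selection condition shown in the hunk above."""
        return (cfg.n_routed_experts is not None
                and layer_idx >= cfg.first_k_dense_replace
                and layer_idx % cfg.moe_layer_freq == 0
                and enable_multistream_moe)

    cfg = _CfgStub()
    assert not uses_replace_allreduce_o_proj(0, cfg, True)    # dense layer
    assert uses_replace_allreduce_o_proj(10, cfg, True)       # MoE layer
    assert not uses_replace_allreduce_o_proj(10, cfg, False)  # feature disabled
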
@@ -456,14 +557,6 @@ def __init__(
             o_proj=self.o_proj,
         )

-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -530,6 +623,10 @@ def __init__(
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
         self.layer_idx = layer_idx
+        self.layers = config.num_hidden_layers
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tp_group().rank_in_group
+        ascend_config = get_ascend_config()
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
             attn_cls = CustomDeepseekV2MLAAttention
@@ -561,6 +658,8 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
+                and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
         else:
             self.mlp = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -569,11 +668,13 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            self.mla_moe_communication = False
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.first_k_dense_replace = config.first_k_dense_replace

     def forward(
         self,
@@ -582,8 +683,13 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
+        if attn_metadata is not None and attn_metadata.num_decodes > 0:
+            mla_moe_communication = self.mla_moe_communication and replace_allreduce
+        else:
+            mla_moe_communication = False
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
@@ -595,6 +701,9 @@ def forward(
             # to save npu memory because they're no longer used.
             dispose_tensor(previous_hidden_states)
             dispose_tensor(previous_residual)
+            if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
+                hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                                 dim=0)

         hidden_states = self.self_attn(
             positions=positions,
@@ -603,6 +712,13 @@ def forward(
             attn_metadata=attn_metadata,
         )

+        if mla_moe_communication and residual.shape[0] != hidden_states.shape[
+                0]:
+            chunk_hidden_states = torch.tensor_split(residual,
+                                                     self.tp_size,
+                                                     dim=0)
+            residual = chunk_hidden_states[self.tp_rank]
+
         if hidden_states.dtype == torch.float16:
             # Fix FP16 overflow
             # We scale both hidden_states and residual before
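
After the attention call, hidden_states may have been reduce-scattered down to num_tokens / tp_size rows while residual still carries the full token count, so the hunk above splits the residual with torch.tensor_split and keeps only this rank's slice before the residual add. A runnable shape check with plain tensors (no process group involved; the sizes are arbitrary):

    import torch

    tp_size, tp_rank = 4, 1
    num_tokens, hidden = 8, 16

    residual = torch.randn(num_tokens, hidden)                   # full token dim
    hidden_states = torch.randn(num_tokens // tp_size, hidden)   # after reduce-scatter

    if residual.shape[0] != hidden_states.shape[0]:
        chunks = torch.tensor_split(residual, tp_size, dim=0)
        residual = chunks[tp_rank]                               # this rank's token slice

    assert residual.shape == hidden_states.shape                 # residual add now lines up
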
@@ -618,7 +734,9 @@ def forward(
                 hidden_states, residual)

         if isinstance(self.mlp, CustomDeepseekV2MoE):
-            hidden_states = self.mlp(hidden_states, attn_metadata)
+            hidden_states = self.mlp(hidden_states,
+                                     attn_metadata,
+                                     replace_allreduce=mla_moe_communication)
         else:
             hidden_states = self.mlp(hidden_states)

@@ -631,6 +749,10 @@ def forward(
             # The scaling of DeepseekV2MOE output would be done in the forward
             # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
+        if mla_moe_communication and self.layer_idx == self.layers - 1:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                             dim=0)
+            residual = tensor_model_parallel_all_gather(residual, dim=0)

         return hidden_states, residual

@@ -649,6 +771,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.tp_size = get_tensor_model_parallel_world_size()

         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -701,13 +824,18 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, residual,
+                positions,
+                hidden_states,
+                residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata,
+                replace_allreduce=replace_allreduce)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
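
The new path is gated at two levels: the model forward only sets replace_allreduce when the flattened token count divides evenly by tp_size, and each decoder layer additionally requires its mla_moe_communication setting plus a batch that contains decode tokens (attn_metadata.num_decodes > 0); the extra all-gathers at layer entry (past first_k_dense_replace) and at the last layer restore the full token dimension where it is needed. A compact restatement of that gating; the helper names are mine, not from the diff:

    def model_level_gate(num_tokens: int, tp_size: int) -> bool:
        # CustomDeepseekV2Model.forward: the replace_allreduce flag.
        return num_tokens % tp_size == 0

    def layer_level_gate(replace_allreduce: bool,
                         mla_moe_communication: bool,
                         num_decodes: int) -> bool:
        # CustomDeepseekV2DecoderLayer.forward: the mla_moe_communication flag.
        return mla_moe_communication and replace_allreduce and num_decodes > 0

    assert model_level_gate(num_tokens=8, tp_size=4)
    assert not model_level_gate(num_tokens=9, tp_size=4)
    assert layer_level_gate(True, True, num_decodes=8)
    assert not layer_level_gate(True, True, num_decodes=0)   # prefill-only batch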

vllm_ascend/ops/fused_moe.py

Lines changed: 6 additions & 3 deletions

@@ -1126,7 +1126,8 @@ def forward(self,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
-                gate: Optional[Any] = None):
+                gate: Optional[Any] = None,
+                replace_allreduce: bool = False):
         assert self.quant_method is not None

         if top_k:
@@ -1158,7 +1159,8 @@ def forward(self,
             shared_hidden_states = shared_experts(hidden_states)

         tp_size = get_tensor_model_parallel_world_size()
-        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
+        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
+                and not replace_allreduce):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1217,7 +1219,8 @@ def forward(self,
         if isinstance(e_hidden_states, tuple):
             e_hidden_states, shared_hidden_states = e_hidden_states

-        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
+        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
+                and not replace_allreduce):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
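
When replace_allreduce is set, the MoE forward skips its own tensor-parallel chunking path (the padding plus all-gather of expert outputs), because the hidden states it receives are already this rank's token slice from the reduce-scatter in o_proj; splitting them again would shrink the batch twice. A minimal sketch of the changed guard, with a stub enum standing in for vllm_ascend's FusedMoEState:

    from enum import Enum

    class FusedMoEState(Enum):       # stub for illustration only
        AllGather = 1
        Other = 2

    def needs_tp_chunking(tp_size: int,
                          fused_moe_state: FusedMoEState,
                          replace_allreduce: bool) -> bool:
        # Mirrors the guard added around the pad/all-gather path above.
        return (tp_size > 1
                and fused_moe_state != FusedMoEState.AllGather
                and not replace_allreduce)

    assert needs_tp_chunking(4, FusedMoEState.Other, replace_allreduce=False)
    assert not needs_tp_chunking(4, FusedMoEState.Other, replace_allreduce=True)
    assert not needs_tp_chunking(1, FusedMoEState.Other, replace_allreduce=False)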
