
Commit cdae151

avoid performing index selection of the sin/cos cache in every layer
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e878d56 commit cdae151
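
In outline, the commit hoists the rope cos/sin index selection out of every attention layer: the model-level forward gathers cos/sin for the current positions once per batch, stashes the result on attn_metadata.decode / attn_metadata.prefill, and each MLA layer simply reads the precomputed tensors back (alongside a running_in_graph flag that is now set once on the forward context). The following is a minimal, runnable sketch of that pattern; the class and attribute names (RopeCache, DecodeMeta, index_select_rope) are illustrative stand-ins, not the vllm-ascend API.

import torch

# Toy stand-ins for the rotary-embedding cache and the per-batch decode metadata.
class RopeCache:
    def __init__(self, max_pos: int = 4096, rot_dim: int = 64):
        self.cos_cached = torch.randn(max_pos, rot_dim)
        self.sin_cached = torch.randn(max_pos, rot_dim)

class DecodeMeta:
    def __init__(self, input_positions: torch.Tensor):
        self.input_positions = input_positions
        self.cos = None  # filled once per forward pass after this change
        self.sin = None

def index_select_rope(cache, positions, dtype):
    # One gather for cos and one for sin, reshaped to broadcast over heads,
    # mirroring the cos[:, None, None, :] layout used in the diff below.
    cos = cache.cos_cached.to(dtype)[positions][:, None, None, :]
    sin = cache.sin_cached.to(dtype)[positions][:, None, None, :]
    return cos, sin

cache = RopeCache()
meta = DecodeMeta(input_positions=torch.tensor([5, 9, 42]))

# Before: each attention layer repeated index_select_rope() in its forward.
# After: the model-level forward runs it once and every layer reuses the result.
meta.cos, meta.sin = index_select_rope(cache, meta.input_positions, torch.float16)
for _ in range(3):  # stands in for the loop over decoder layers
    cos, sin = meta.cos, meta.sin  # cheap attribute reads, no per-layer IndexByTensor
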

3 files changed: +41, -21 lines


vllm_ascend/ascend_forward_context.py

Lines changed: 11 additions & 0 deletions

@@ -7,6 +7,8 @@
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context, set_forward_context
 
+from vllm_ascend.ascend_config import get_ascend_config
+
 
 class FusedMoEState(Enum):
     AllGather = 0
@@ -55,6 +57,15 @@ def set_ascend_forward_context(
 
         forward_context.in_profile_run = in_profile_run
 
+        ascend_config = get_ascend_config()
+        from vllm_ascend.attention.attention_v1 import AscendAttentionState
+        forward_context.running_in_graph = ascend_config.torchair_graph_config.enabled and \
+            attn_metadata and \
+            attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+            ]
+
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
             forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(

vllm_ascend/attention/mla_v1.py

Lines changed: 6 additions & 21 deletions

@@ -9,6 +9,7 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -1042,9 +1043,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
-        ]
+        self.running_in_graph = get_forward_context().running_in_graph
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1082,15 +1081,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1125,15 +1117,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(

vllm_ascend/models/deepseek_v2.py

Lines changed: 24 additions & 0 deletions

@@ -676,6 +676,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.cos_cached = self.layers[
+            self.start_layer].self_attn.rotary_emb.cos_cached
+        self.sin_cached = self.layers[
+            self.start_layer].self_attn.rotary_emb.sin_cached
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -700,6 +706,24 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
+        forward_context = get_forward_context()
+        # Index select sin/cos for rope here.
+        if attn_metadata is not None:
+            if attn_metadata.num_decodes > 0 and forward_context.running_in_graph:
+                cos = self.cos_cached.to(dtype=hidden_states.dtype)
+                sin = self.sin_cached.to(dtype=hidden_states.dtype)
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                attn_metadata.decode.cos = cos[:, None, None, :]
+                attn_metadata.decode.sin = sin[:, None, None, :]
+            if attn_metadata.num_prefills > 0 and self.torchair_graph_enabled:
+                cos = self.cos_cached.to(dtype=hidden_states.dtype)
+                sin = self.sin_cached.to(dtype=hidden_states.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                attn_metadata.prefill.cos = cos[:, None, None, :]
+                attn_metadata.prefill.sin = sin[:, None, None, :]
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
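
For a rough sense of what the hoist saves per forward pass, assume a hypothetical 60-layer DeepSeek-V2 stack (the real depth comes from the model config): the cos/sin gathers drop from one pair inside every layer to a single pair computed in the model forward for each decode/prefill branch.

num_hidden_layers = 60  # hypothetical layer count; the real value comes from the model config
gathers_before = num_hidden_layers * 2  # cos + sin re-indexed inside every layer's forward
gathers_after = 2                       # cos + sin indexed once in the model forward
print(gathers_before, "->", gathers_after)  # 120 -> 2
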
