
Commit 9621b4d

avoid performing index selection of sin/cos cache every layer

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e878d56
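The optimization: the metadata builder now gathers cos/sin from the layer-0 rotary embedding cache once per batch and stores the result on the prefill/decode metadata, so every attention layer reuses the same tensors instead of repeating the index selection per layer. Below is a minimal sketch of that pattern using hypothetical names (DecodeMeta, build_decode_meta), not the actual vllm-ascend classes:

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class DecodeMeta:
    """Per-batch attention metadata carrying precomputed rope factors."""
    # Shaped [num_tokens, 1, 1, rope_dim] so they broadcast against
    # [num_tokens, num_heads, 1, rope_dim] query/key tensors.
    cos: Optional[torch.Tensor] = None
    sin: Optional[torch.Tensor] = None

def build_decode_meta(positions: torch.Tensor,
                      cos_cache: torch.Tensor,
                      sin_cache: torch.Tensor) -> DecodeMeta:
    # One gather per batch; every decoder layer then reads meta.cos/meta.sin
    # instead of re-indexing the cached tables with the same positions.
    cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cache[positions].unsqueeze(1).unsqueeze(2)
    return DecodeMeta(cos=cos, sin=sin)

# Toy usage: a 4096-position cache with rope_dim 64 and a batch of 3 tokens.
rope_dim, max_pos = 64, 4096
angles = torch.arange(max_pos).unsqueeze(1) * torch.ones(rope_dim)
meta = build_decode_meta(torch.tensor([5, 17, 42]), angles.cos(), angles.sin())
print(meta.cos.shape)  # torch.Size([3, 1, 1, 64])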

File tree: 2 files changed, +43 -25 lines changed
vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 25 deletions
@@ -9,6 +9,7 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -93,6 +94,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -105,6 +108,8 @@ class AscendMLADecodeMetadata:
     max_seq_lens: int
     seq_lens_list: list[int]
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -215,6 +220,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -333,13 +341,21 @@ def build_torchair_graph_dummy(
                                  -1,
                                  dtype=torch.int32,
                                  device=device)
+        sin = torch.ones(num_reqs, 1, 1, self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs, 1, 1, self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -388,6 +404,12 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model().model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model().model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:
+            self.cos_cache = self.cos_cache.to(self.runner.dtype)
+            self.sin_cache = self.sin_cache.to(self.runner.dtype)

         prefill_metadata = None
         chunked_context_metadata = None
@@ -434,18 +456,22 @@ def build(
                 chunk_seq_lens=chunk_seq_lens,
                 workspace=self.chunked_prefill_workspace,
             )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[prefill_input_positions].unsqueeze(1).unsqueeze(2)
+            sin = self.sin_cache[prefill_input_positions].unsqueeze(1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )

         decode_metadata = None
@@ -486,14 +512,18 @@ def build(
                     dtype=input_positions.dtype,
                     device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
+            cos = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)

             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
                 max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                sin=sin,
+                cos=cos)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1042,9 +1072,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
-        ]
+        self.running_in_graph = get_forward_context().running_in_graph
        num_actual_toks = attn_metadata.num_actual_tokens
        if k_pe is None and not self.running_in_graph:
            kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1082,15 +1110,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1125,15 +1146,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
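In the per-layer forward, the torchair-graph paths now read attn_metadata.decode.cos / attn_metadata.prefill.cos (and the matching sin) and pass them straight to rope_single and the KV helpers, instead of slicing and gathering rotary_emb.cos_cached / sin_cached in every layer. rope_single's body is not part of this diff; the snippet below is only a generic neox-style application with illustrative shapes, showing how broadcast cos/sin factors of shape [num_tokens, 1, 1, rope_dim] are typically consumed:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard half rotation: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x: torch.Tensor, cos: torch.Tensor,
               sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, num_heads, 1, rope_dim]
    # cos/sin: [num_tokens, 1, 1, rope_dim], broadcast over the head dim.
    return x * cos + rotate_half(x) * sin

q_pe = torch.randn(3, 16, 1, 64)   # 3 tokens, 16 heads, rope_dim 64 (illustrative)
cos = torch.randn(3, 1, 1, 64)     # read once from the batch metadata
sin = torch.randn(3, 1, 1, 64)
print(apply_rope(q_pe, cos, sin).shape)  # torch.Size([3, 16, 1, 64])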

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 0 deletions
@@ -1607,6 +1607,10 @@ def _dummy_run(
                         attn_metadata.decode.block_table)
                     torch._dynamo.mark_static(
                         attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.cos)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                     for kv in self.kv_caches:
                         assert isinstance(
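Because the precomputed sin/cos now enter the captured decode graph as metadata inputs, _dummy_run marks them static alongside block_table and input_positions so dynamo does not treat their shapes as dynamic. A small standalone sketch of that guard, with an illustrative compiled function standing in for the real decode step:

import torch

@torch.compile(dynamic=False)
def rope_decode_step(q_pe: torch.Tensor, cos: torch.Tensor,
                     sin: torch.Tensor) -> torch.Tensor:
    # Stand-in computation; only the static-shape handling matters here.
    return q_pe * cos + q_pe.flip(-1) * sin

cos = torch.randn(8, 1, 1, 64)
sin = torch.randn(8, 1, 1, 64)
# Mark the per-batch rope tensors static, mirroring what _dummy_run does for
# block_table/input_positions, so their shapes do not trigger re-tracing.
torch._dynamo.mark_static(cos)
torch._dynamo.mark_static(sin)
out = rope_decode_step(torch.randn(8, 16, 1, 64), cos, sin)
print(out.shape)  # torch.Size([8, 16, 1, 64])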
