
Commit ab7f407

avoid performing index selection of sin/cos cache every layer
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e1d282d commit ab7f407
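
Before this change, every decoder layer's forward pass sliced the rotary embedding's cos_cached/sin_cached tables, indexed them by the current token positions, and reshaped the result. The commit hoists that gather into attention-metadata construction, so the index selection happens once per step and each layer just reads the precomputed sin/cos tensors from the metadata. A minimal sketch of the idea (hypothetical names and sizes, not the vllm-ascend classes):

import torch

num_layers, max_pos, rope_dim = 4, 4096, 64
cos_cache = torch.randn(max_pos, rope_dim)   # stands in for rotary_emb.cos_cached
sin_cache = torch.randn(max_pos, rope_dim)   # stands in for rotary_emb.sin_cached
positions = torch.randint(0, max_pos, (8,))  # token positions for this step

# Old pattern: each layer repeats the same index-select and reshape.
def per_layer_rope_inputs():
    cos = cos_cache[positions][:, None, None, :]
    sin = sin_cache[positions][:, None, None, :]
    return cos, sin

# New pattern: gather once while building the step's attention metadata;
# layers only read the cached tensors.
class StepMetadata:
    def __init__(self, positions: torch.Tensor):
        self.cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)
        self.sin = sin_cache[positions].unsqueeze(1).unsqueeze(2)

meta = StepMetadata(positions)
for _ in range(num_layers):
    cos, sin = meta.cos, meta.sin  # no per-layer indexing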

File tree

vllm_ascend/attention/mla_v1.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +56 −22 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 54 additions & 22 deletions
@@ -93,6 +93,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -106,6 +108,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -217,6 +221,9 @@ def __init__(self,
             )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -348,6 +355,18 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -356,7 +375,8 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -408,6 +428,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore

         prefill_metadata = None
         chunked_context_metadata = None
@@ -454,18 +484,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )

         decode_metadata = None
@@ -510,8 +548,15 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None

             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -521,7 +566,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1113,15 +1159,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1156,15 +1195,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
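
For reference on the shapes involved: the builder above gathers rows of the cached tables by token position and adds two singleton dims so the result broadcasts over attention heads, giving (num_tokens, 1, 1, rope_dim). A standalone sketch of that gather, with hypothetical sizes rather than the real model config:

import torch

num_tokens, max_pos, rope_dim = 8, 4096, 64
cos_cached = torch.randn(max_pos, rope_dim)  # per-position cos table
sin_cached = torch.randn(max_pos, rope_dim)  # per-position sin table
input_positions = torch.randint(0, max_pos, (num_tokens,))

# Same indexing/unsqueeze pattern as in the builder: one row per token,
# then two singleton dims so the tensors broadcast across heads.
cos = cos_cached[input_positions].unsqueeze(1).unsqueeze(2)
sin = sin_cached[input_positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (num_tokens, 1, 1, rope_dim)
assert sin.shape == (num_tokens, 1, 1, rope_dim)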

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1647,6 +1647,8 @@ def _dummy_run(
                         attn_metadata.decode.block_table)
                     torch._dynamo.mark_static(
                         attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                     for kv in self.kv_caches:
                         assert isinstance(
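
Since the new sin/cos tensors are part of the decode metadata consumed by the captured graph, they are marked static alongside the other per-step tensors so that dynamic-shape guards do not force recompilation. A minimal sketch of the call (hypothetical tensors; torch._dynamo.mark_static is the API used above):

import torch

# Hypothetical per-step buffers reused across graph replays.
sin = torch.zeros(8, 1, 1, 64)
cos = torch.zeros(8, 1, 1, 64)

# mark_static tells dynamo to treat the tensors' dimensions as static,
# avoiding dynamic-shape guards and recompiles when the graph is reused.
torch._dynamo.mark_static(sin)
torch._dynamo.mark_static(cos)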
