
Commit 80f3214

committed
avoid performing index selection of sin/cos cache every layer
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e1d282d commit 80f3214
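For context, a minimal standalone sketch of the idea behind this commit (names and sizes are illustrative, not the vLLM-Ascend code itself): rather than every attention layer slicing and indexing the shared rotary cos/sin caches in its forward pass, the index selection is done once per batch when the attention metadata is built, and each layer reuses the resulting tensors.

```python
import torch

# Hypothetical sketch of the optimization; synthetic names and sizes.
num_layers, max_pos, rope_dim, num_tokens = 48, 4096, 64, 16
cos_cached = torch.randn(max_pos, rope_dim)
sin_cached = torch.randn(max_pos, rope_dim)
positions = torch.randint(0, max_pos, (num_tokens,))

def forward_old():
    # Before: each layer repeats the same gather over the cos/sin caches.
    for _ in range(num_layers):
        cos = cos_cached[positions][:, None, None, :]
        sin = sin_cached[positions][:, None, None, :]
        # ... apply rotary embedding with cos/sin ...

# After: gather once while building the per-batch attention metadata ...
cos = cos_cached[positions].unsqueeze(1).unsqueeze(2)  # [num_tokens, 1, 1, rope_dim]
sin = sin_cached[positions].unsqueeze(1).unsqueeze(2)

def forward_new():
    # ... and every layer simply reads the precomputed tensors.
    for _ in range(num_layers):
        pass  # use cos/sin directly, no per-layer index selection
```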

File tree

2 files changed (+49 -22 lines changed)


vllm_ascend/attention/mla_v1.py

Lines changed: 47 additions & 22 deletions
@@ -9,6 +9,7 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -93,6 +94,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -106,6 +109,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -217,6 +222,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -348,6 +356,18 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -356,7 +376,8 @@
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -408,6 +429,14 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:
+            self.cos_cache = self.cos_cache.to(self.runner.dtype)
+            self.sin_cache = self.sin_cache.to(self.runner.dtype)
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -454,18 +483,24 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[prefill_input_positions].unsqueeze(
+                1).unsqueeze(2)
+            sin = self.sin_cache[prefill_input_positions].unsqueeze(
+                1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -510,8 +545,11 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
+                sin = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -521,7 +559,8 @@
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1113,15 +1152,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1156,15 +1188,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
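A side note on the builder changes above: indexing the cached cos/sin tables with the batch positions and expanding via unsqueeze(1).unsqueeze(2) reproduces the [num_tokens, 1, 1, rope_dim] layout that the per-layer code previously obtained with cos[:, None, None, :]. A quick illustrative check with synthetic sizes (not the actual cache tensors):

```python
import torch

# Synthetic sizes, purely illustrative.
max_pos, rope_dim, num_tokens = 1024, 64, 8
cos_cached = torch.randn(max_pos, rope_dim)
positions = torch.randint(0, max_pos, (num_tokens,))

a = cos_cached[positions].unsqueeze(1).unsqueeze(2)   # builder-side path
b = cos_cached[positions][:, None, None, :]           # old per-layer path
assert a.shape == (num_tokens, 1, 1, rope_dim)
assert torch.equal(a, b)
```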

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1647,6 +1647,8 @@ def _dummy_run(
                             attn_metadata.decode.block_table)
                         torch._dynamo.mark_static(
                             attn_metadata.decode.input_positions)
+                        torch._dynamo.mark_static(attn_metadata.decode.sin)
+                        torch._dynamo.mark_static(attn_metadata.decode.cos)
                         torch._dynamo.mark_static(attn_metadata.slot_mapping)
                         for kv in self.kv_caches:
                             assert isinstance(
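Because the precomputed sin/cos now enter the captured decode graph as metadata inputs, the dummy run marks them static alongside the existing block table and input positions. A minimal sketch of torch._dynamo.mark_static in isolation (illustrative tensors shaped like the decode metadata's sin/cos, not the runner's own state):

```python
import torch

# Tensors shaped [num_reqs, 1, 1, rope_dim], illustrative values only.
sin = torch.ones(4, 1, 1, 64, dtype=torch.float16)
cos = torch.ones(4, 1, 1, 64, dtype=torch.float16)

# Tell Dynamo to treat these tensors' shapes as static, so graph capture
# does not re-specialize on them between runs.
torch._dynamo.mark_static(sin)
torch._dynamo.mark_static(cos)
```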

0 commit comments
