
Commit ea3dc31

[0.9.1][Perf] Optimize the number of rope-related index selections in deepseek. (#1614)
This PR avoids performing the index selection of the sin/cos cache in every layer in deepseek: the cos/sin values for the current token positions are now gathered once per batch in the attention-metadata builder and reused by all layers.

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent: 5559443
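In outline, the change hoists the rope index selection out of the per-layer forward path and into the attention-metadata builder. Below is a minimal sketch of the before/after behavior; all names and shapes (cos_cached, positions, Metadata, ...) are illustrative stand-ins, not the actual vllm-ascend classes.

import torch

rope_dim, max_pos, num_tokens, num_layers = 64, 4096, 8, 61
cos_cached = torch.randn(max_pos, rope_dim)  # stand-in for rotary_emb.cos_cached
sin_cached = torch.randn(max_pos, rope_dim)  # stand-in for rotary_emb.sin_cached
positions = torch.randint(0, max_pos, (num_tokens,))

# Before: every layer repeated the same gather (IndexByTensor) on the caches.
def rope_inputs_per_layer():
    cos = cos_cached[positions][:, None, None, :]  # one gather per layer
    sin = sin_cached[positions][:, None, None, :]
    return cos, sin

# After: the metadata builder gathers once per batch and stores the result
# on the shared attention metadata ...
class Metadata:
    cos: torch.Tensor
    sin: torch.Tensor

meta = Metadata()
meta.cos = cos_cached[positions].unsqueeze(1).unsqueeze(2)  # (tokens, 1, 1, rope_dim)
meta.sin = sin_cached[positions].unsqueeze(1).unsqueeze(2)

# ... and every layer simply reads the precomputed tensors.
for _ in range(num_layers):
    cos, sin = meta.cos, meta.sin
    assert torch.equal(cos, rope_inputs_per_layer()[0])
    assert torch.equal(sin, rope_inputs_per_layer()[1])

DeepSeek runs this attention block in every one of its dozens of layers, so moving the gather into the builder removes O(num_layers) redundant index selections (and dtype casts) per scheduler step.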

File tree

vllm_ascend/attention/mla_v1.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +56 -22


vllm_ascend/attention/mla_v1.py

Lines changed: 54 additions & 22 deletions
@@ -81,6 +81,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -94,6 +96,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -205,6 +209,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -336,6 +343,18 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -344,7 +363,8 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -396,6 +416,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -442,18 +472,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -498,8 +536,15 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -509,7 +554,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1101,15 +1147,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1144,15 +1183,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
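A note on the builder logic above: the cos/sin tables are fetched lazily from layer 0's rotary embedding on the first build() call and cast to the runner dtype up front, so steady-state steps pay only for the single per-batch gather. A minimal sketch of that pattern, assuming (as the diff does) that every DeepSeek layer shares the same rope tables; the class and `runner` object below are hypothetical stand-ins, not the real builder:

class MLAMetadataBuilderSketch:
    """Hypothetical skeleton mirroring the lazy fetch in the diff above;
    `runner` stands in for the model runner holding the model and dtype."""

    def __init__(self, runner):
        self.runner = runner
        self.cos_cache = None  # filled on the first build() call
        self.sin_cache = None

    def ensure_rope_caches(self):
        # All layers share the same rope tables, so reading them from
        # layer 0 once is sufficient (assumption reflected in the diff).
        if self.cos_cache is None:
            rotary = self.runner.get_model().model.layers[0].self_attn.rotary_emb
            self.cos_cache = rotary.cos_cached
            self.sin_cache = rotary.sin_cached
            if self.cos_cache.dtype != self.runner.dtype:
                # Cast once up front instead of once per layer per step.
                self.cos_cache = self.cos_cache.to(self.runner.dtype)
                self.sin_cache = self.sin_cache.to(self.runner.dtype)

The dummy torchair-graph path has no real positions to index with, which is why build_torchair_graph_dummy fabricates all-ones sin/cos placeholders of shape (num_reqs, 1, 1, rope_dim) instead.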

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1666,6 +1666,8 @@ def _dummy_run(
                         attn_metadata.decode.block_table)
                     torch._dynamo.mark_static(
                         attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                     for kv in self.kv_caches:
                         assert isinstance(
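Because the precomputed sin/cos tensors are now inputs to the captured decode graph, the dummy run marks them static alongside the other decode-metadata tensors; otherwise dynamo could treat their shapes as dynamic and trigger guard recompilation. A standalone illustration of the call (the shape here is just an example):

import torch

# torch._dynamo.mark_static asks the compiler to treat this tensor's shape
# as fixed across graph replays, avoiding dynamic-shape guards.
sin = torch.ones(4, 1, 1, 64)  # example (num_reqs, 1, 1, rope_dim) tensor
torch._dynamo.mark_static(sin)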

0 commit comments