[0.9.1][Perf] Optimize the number of rope-related index selections in deepseek. #1614

Merged · 1 commit · Jul 8, 2025
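In short: before this change, every DeepSeek attention layer re-indexed the rotary-embedding cos/sin caches by token position on each forward pass; this PR hoists that gather into the attention-metadata builder, so it happens once per batch and is reused by all layers. A minimal standalone sketch of the idea (`NUM_LAYERS`, `ROPE_DIM`, and the random caches are illustrative stand-ins, not values from the PR):

```python
import torch

NUM_LAYERS = 4   # illustrative; the real model has many more layers
ROPE_DIM = 64    # stands in for qk_rope_head_dim
MAX_POS = 1024

cos_cache = torch.randn(MAX_POS, ROPE_DIM)
sin_cache = torch.randn(MAX_POS, ROPE_DIM)
positions = torch.randint(0, MAX_POS, (8,))  # one position id per token

def per_layer_gather():
    """Old pattern: every layer performs its own index selection."""
    outs = []
    for _ in range(NUM_LAYERS):
        cos = cos_cache[positions][:, None, None, :]
        sin = sin_cache[positions][:, None, None, :]
        outs.append((cos, sin))
    return outs

def precomputed_gather():
    """New pattern: one gather in the metadata builder, shared by all layers."""
    cos = cos_cache[positions][:, None, None, :]
    sin = sin_cache[positions][:, None, None, :]
    return [(cos, sin) for _ in range(NUM_LAYERS)]

# Identical results, but the second variant issues 1/NUM_LAYERS as many
# gather (IndexByTensor) ops per decode step.
for (c1, s1), (c2, s2) in zip(per_layer_gather(), precomputed_gather()):
    assert torch.equal(c1, c2) and torch.equal(s1, s2)
```

On the torchair graph path this also removes the per-layer slicing of the caches by `max_position_embeddings * scaling_factor`, as the diff below shows.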
76 changes: 54 additions & 22 deletions vllm_ascend/attention/mla_v1.py
@@ -93,6 +93,8 @@ class ChunkedContextMetadata:
max_query_len: int
max_seq_lens: int
chunked_context: Optional[ChunkedContextMetadata] = None
sin: torch.Tensor = None
cos: torch.Tensor = None


@dataclass
@@ -106,6 +108,8 @@ class AscendMLADecodeMetadata:
seq_lens_list: list[int]
actual_seq_q_lens: Optional[list[int]] = None
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None


@dataclass
@@ -217,6 +221,9 @@ def __init__(self,
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -348,6 +355,18 @@ def build_torchair_graph_dummy(
else:
attn_state = AscendAttentionState.DecodeOnly
num_decode_tokens = 1
sin = torch.ones(num_reqs,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
device=device)
cos = torch.ones(num_reqs,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
device=device)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
@@ -356,7 +375,8 @@
max_seq_lens=1,
attn_mask=self.runner.spec_attn_mask,
actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
)
sin=sin,
cos=cos)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
@@ -408,6 +428,16 @@ def build(
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
query_start_loc = common_attn_metadata.query_start_loc
if self.cos_cache is None:
self.cos_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.cos_cached
self.sin_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.runner.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.runner.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.runner.dtype) # type: ignore

prefill_metadata = None
chunked_context_metadata = None
@@ -454,18 +484,26 @@ def build(
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)

prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
query_lens=query_lens[tokens_start:],
seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:],
input_positions=input_positions[tokens_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)

decode_metadata = None
@@ -510,8 +548,15 @@ def build(
actual_seq_q_lens = query_start_loc[1:].tolist(
) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
num_reqs_pad_size]
cos = self.cos_cache[
input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
else:
seq_lens_list = seq_lens.tolist()
cos, sin = None, None

decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
@@ -521,7 +566,8 @@
max_seq_lens=max_seq_lens,
attn_mask=self.runner.spec_attn_mask,
actual_seq_q_lens=actual_seq_q_lens,
)
sin=sin,
cos=cos)

return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
@@ -1113,15 +1159,8 @@ def forward(
decode_k_nope = None
assert attn_metadata.decode is not None
if self.running_in_graph:
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
@@ -1156,15 +1195,8 @@
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.torchair_graph_enabled:
num_tokens = prefill_hs_or_q_c.shape[0]
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
cos = cos[attn_metadata.prefill.input_positions]
sin = sin[attn_metadata.prefill.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin

prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
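The `.unsqueeze(1).unsqueeze(2)` in the builder shapes each gathered table as `[num_tokens, 1, 1, rope_dim]`, which then broadcasts against a per-head query slice. A small shape sketch (the `[num_tokens, num_heads, 1, rope_dim]` query layout is an assumption for illustration):

```python
import torch

num_tokens, num_heads, rope_dim = 8, 16, 64
cos_cache = torch.randn(4096, rope_dim)   # [max_position, rope_dim]
positions = torch.arange(num_tokens)

# One gather per batch, then two unsqueezes: [T, rope_dim] -> [T, 1, 1, rope_dim].
cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (num_tokens, 1, 1, rope_dim)

# That layout broadcasts over the head dimension, so rope can be applied
# with plain elementwise multiplies on the rope part of the query.
q_pe = torch.randn(num_tokens, num_heads, 1, rope_dim)
assert (q_pe * cos).shape == q_pe.shape
```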
2 changes: 2 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1647,6 +1647,8 @@ def _dummy_run(
attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(
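Since the decode path runs as a captured graph, the new `sin`/`cos` tensors get the same `mark_static` treatment as the rest of the decode metadata: marking an input static tells Dynamo to specialize on its shape rather than guard it as dynamic. A generic sketch of that behavior (plain `torch.compile`, not the Ascend torchair path):

```python
import torch

def apply_cos(x, cos):
    return x * cos

x = torch.randn(8, 1, 1, 64)
cos = torch.randn(8, 1, 1, 64)

# mark_static pins every dimension of these tensors, so the compiled graph
# specializes on the current shapes instead of tracing them as dynamic.
torch._dynamo.mark_static(x)
torch._dynamo.mark_static(cos)

compiled = torch.compile(apply_cos)
assert compiled(x, cos).shape == x.shape
```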