diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 98f0a3389c..8712cf396f 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -93,6 +93,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -106,6 +108,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -217,6 +221,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -348,6 +355,18 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -356,7 +375,8 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -408,6 +428,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -454,18 +484,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -510,8 +548,15 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -521,7 +566,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1113,15 +1159,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1156,15 +1195,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e0ab79be45..fe57ed240e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1647,6 +1647,8 @@ def _dummy_run(
                         attn_metadata.decode.block_table)
                     torch._dynamo.mark_static(
                         attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                     for kv in self.kv_caches:
                         assert isinstance(