@@ -9,6 +9,7 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -1042,9 +1043,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
-        ]
+        self.running_in_graph = get_forward_context().running_in_graph
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
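The per-layer check against attn_metadata.attn_state is replaced by a read from vLLM's per-step forward context, so the flag has to be computed once upstream and attached before the layers run. A minimal sketch of that wiring, assuming a runner-side helper (the name _run_model_step and its call site are illustrative, not part of this diff) and assuming running_in_graph is attached to the active ForwardContext dynamically rather than being a stock field:

from vllm.forward_context import get_forward_context, set_forward_context

from vllm_ascend.attention.attention_v1 import AscendAttentionState


def _run_model_step(self, model, attn_metadata, vllm_config, **model_kwargs):
    # Decide once per step whether this pass runs inside the torchair graph.
    running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
    ]
    with set_forward_context(attn_metadata, vllm_config):
        # Assumption: the flag is stashed on the active ForwardContext so every
        # MLA layer can read it back via get_forward_context().running_in_graph.
        get_forward_context().running_in_graph = running_in_graph
        return model(**model_kwargs)

This keeps the graph-mode decision in one place; each layer's forward only reads the cached flag, as the + line above shows.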
@@ -1082,15 +1081,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
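The block deleted above (and its prefill twin in the next hunk) explains the new attn_metadata.decode.cos / attn_metadata.decode.sin fields: the cos/sin gather now happens once when the decode metadata is built instead of once per MLA layer per step. A rough sketch of that precomputation, mirroring the removed logic; the helper name and its call site in the metadata builder are assumptions, not shown in this diff:

import torch


def build_rope_cos_sin(rotary_emb, input_positions: torch.Tensor,
                       dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical builder helper: gather per-token cos/sin once per batch.
    seq_len = int(rotary_emb.max_position_embeddings * rotary_emb.scaling_factor)
    cos = rotary_emb.cos_cached[:seq_len].to(dtype=dtype)
    sin = rotary_emb.sin_cached[:seq_len].to(dtype=dtype)
    # Select the rows for this batch's positions and add broadcast dims,
    # matching the [num_tokens, 1, 1, rot_dim] layout the layer used before.
    cos = cos[input_positions][:, None, None, :]
    sin = sin[input_positions][:, None, None, :]
    return cos, sin

With the tensors cached on the metadata, the layer-side change reduces to the two + lines above that simply read them.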
@@ -1125,15 +1117,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
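For reference, the cached cos/sin keep the broadcast shape [num_tokens, 1, 1, rot_dim], which is what lets rope_single consume them directly on both the decode and prefill paths. The sketch below is a plain rotate-half formulation in PyTorch purely to illustrate the shapes involved; the real rope_single dispatches to an NPU rope kernel and may use a different (for example interleaved) element layout:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor) -> torch.Tensor:
    # x viewed as [num_tokens, num_heads, 1, rot_dim]; cos/sin broadcast from
    # [num_tokens, 1, 1, rot_dim] across the head dimension.
    return x * cos + rotate_half(x) * sin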