
Commit 10ee2e7

[BugFix] Fix a bug of running chunked-prefill with torchair. (#1378)
This PR fixes a bug that occurs when running chunked-prefill with torchair.

Signed-off-by: whx-sjtu <2952154980@qq.com>
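For context: in this backend's batches, decode tokens are laid out before prefill tokens, so per-token tensors are trimmed to num_actual_toks and then split at num_decode_tokens. A minimal sketch of the layout the fixed code below relies on (the tensor sizes are illustrative assumptions, not values from the commit):

    import torch

    # Illustrative batch: 2 decode tokens followed by 3 prefill tokens,
    # padded out to 8 rows (sizes assumed for illustration).
    num_decode_tokens = 2
    num_actual_toks = 5

    hidden_states_or_q_c = torch.randn(8, 64)
    k_pe = torch.randn(8, 32)

    # Drop padding, then split by position: decode first, prefill after.
    hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
    decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
    prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]

    k_pe = k_pe[:num_actual_toks, ...].unsqueeze(1)  # add a kv-head dim
    decode_k_pe = k_pe[:num_decode_tokens]
    prefill_k_pe = k_pe[num_decode_tokens:]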
1 parent: 3191183

1 file changed, 14 insertions(+), 14 deletions(-)


vllm_ascend/attention/mla_v1.py

@@ -1054,11 +1054,10 @@ def forward(
             ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
             assert attn_metadata.num_decodes is not None and \
@@ -1077,12 +1076,13 @@ def forward(
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-                k_pe = k_pe[:num_actual_toks, ...]
-                k_pe = k_pe.unsqueeze(1)
-                decode_k_pe = k_pe[:num_decode_tokens]
-                prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            # if not self.torchair_graph_enabled:
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -1146,11 +1146,11 @@

             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])

             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+            prefill_k_c_normed = prefill_k_nope
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
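Taken together, the hunks make the torchair chunked-prefill path consistent: exec_kv_prefill now receives only the prefill tokens (prefill_hs) and the matching tail of slot_mapping, so its output is already prefill-sized and the old re-slice of prefill_k_nope is dropped. A toy check of that invariant (the sizes and the identity stand-in for exec_kv_prefill are assumptions for illustration):

    import torch

    num_decode_tokens, num_prefill_tokens = 2, 3
    num_actual_toks = num_decode_tokens + num_prefill_tokens

    hidden = torch.randn(num_actual_toks, 16)
    slot_mapping = torch.arange(num_actual_toks)

    # Pass only prefill rows and their cache slots, as the fix does.
    prefill_hs = hidden[num_decode_tokens:]
    prefill_slots = slot_mapping[num_decode_tokens:]

    def exec_kv_prefill_stub(x, slots):
        # Stand-in: one output row per input token.
        assert x.shape[0] == slots.shape[0]
        return x

    prefill_k_nope = exec_kv_prefill_stub(prefill_hs, prefill_slots)
    # Already prefill-sized; no further [num_decode_tokens:] slice needed.
    assert prefill_k_nope.shape[0] == num_prefill_tokens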
