
Commit f286265
[BugFix] Address PrefillCacheHit state to fix prefix cache accuracy bug (#1498)
When using the AscendScheduler with prefix caching enabled and chunked prefill disabled, accuracy problems occur because mla_v1 has no branch that handles this scenario. This PR fixes it. Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 5f8241c · commit f286265

File tree

1 file changed: +7 −4 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 4 deletions
@@ -758,7 +758,8 @@ def _forward_prefill(

         if attn_metadata.attn_state in [
                 AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output_torch = torch.empty(num_tokens,
                                             self.num_heads * self.v_head_dim,
@@ -783,7 +784,8 @@ def _forward_prefill(
                                  causal=True)
         elif attn_metadata.attn_state in [
                 AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
         ]:
             attn_lse = torch.empty(self.num_heads,
                                    num_tokens,
@@ -835,13 +837,14 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
         if attn_metadata.attn_state in [
                 AscendAttentionState.ChunkedPrefill,
                 AscendAttentionState.SpecDecoding
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output = attn_output_torch
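The bug pattern the diff fixes is a membership test on the attention state that omitted one enum value, so `PrefillCacheHit` fell through to the error branch instead of the torch-attention path. A minimal, self-contained sketch of that dispatch (the enum value names mirror the diff; the helper `select_prefill_branch` and the returned branch labels are illustrative, not vllm-ascend's actual API):

```python
from enum import Enum, auto


class AscendAttentionState(Enum):
    """Stand-in for vllm-ascend's attention-state enum; member names
    mirror the diff, the rest of this module is illustrative."""
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    ChunkedPrefill = auto()
    SpecDecoding = auto()


def select_prefill_branch(attn_state, chunked_prefill_for_mla=False):
    """Sketch of the dispatch in _forward_prefill: PrefillCacheHit must be
    listed alongside ChunkedPrefill/SpecDecoding, otherwise it reaches the
    RuntimeError branch (the symptom behind the accuracy report)."""
    torch_attn_states = [
        AscendAttentionState.ChunkedPrefill,
        AscendAttentionState.SpecDecoding,
        AscendAttentionState.PrefillCacheHit,  # the value this commit adds
    ]
    if attn_state in torch_attn_states and not chunked_prefill_for_mla:
        return "torch_attention"
    if attn_state == AscendAttentionState.PrefillNoCache:
        return "fused_attention"
    if attn_state in torch_attn_states:
        return "chunked_mla_attention"
    raise RuntimeError("Unexpected attn_state in forward prefill")
```

Before the fix, `select_prefill_branch(AscendAttentionState.PrefillCacheHit)` would have hit the `RuntimeError`; with the added enum member it takes the same path as chunked prefill and speculative decoding.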
