
Commit cae98da

LucasWilkinson authored and lk-chen committed
[BugFix] Fix nightly MLA failure (FA2 + MLA chunked prefill, i.e. V1, producing bad results) (vllm-project#15492)
Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com>
1 parent 03372ac commit cae98da

File tree

1 file changed: +9 additions, −0 deletions


vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 9 additions & 0 deletions
@@ -54,6 +54,15 @@ def merge_attn_states_kernel(
     p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
     s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
+
+    # FA2 and FA3 have different behavior when the sum-exp is 0; this namely
+    # arises with 0-len seqlens. FA3 returns -inf here while FA2 returns inf.
+    # If we see an inf, assume FA2 and convert inf to -inf for consistency
+    # and correctness. Inf generally doesn't make sense in this context outside
+    # of the undefined-behavior/FA2 case, so this is a safe assumption.
+    p_lse = float('-inf') if p_lse == float('inf') else p_lse
+    s_lse = float('-inf') if s_lse == float('inf') else s_lse
+
     max_lse = tl.maximum(p_lse, s_lse)
     p_lse = p_lse - max_lse
     s_lse = s_lse - max_lse
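To see why the patch matters, here is a minimal NumPy sketch (hypothetical names, not vLLM's API) of the per-token, per-head merge that the Triton kernel performs: two partial attention outputs are combined with softmax weights derived from their log-sum-exp (LSE) values, after normalizing FA2's +inf sentinel to -inf so the empty partial contributes a weight of zero.

```python
import numpy as np

def merge_attn_states(p_out, p_lse, s_out, s_lse):
    """Merge prefix/suffix partial attention outputs via their LSE values.

    Sketch only: assumes at least one of the two partials is valid
    (finite LSE), mirroring the kernel's assumption.
    """
    # The fix from the diff: FA2 reports +inf for zero-length sequences,
    # FA3 reports -inf. Map +inf to -inf so exp() gives it zero weight.
    p_lse = np.where(np.isposinf(p_lse), -np.inf, p_lse)
    s_lse = np.where(np.isposinf(s_lse), -np.inf, s_lse)

    # Standard numerically-stable softmax combination of the two partials.
    max_lse = np.maximum(p_lse, s_lse)
    p_w = np.exp(p_lse - max_lse)
    s_w = np.exp(s_lse - max_lse)
    return (p_out * p_w + s_out * s_w) / (p_w + s_w)
```

Without the conversion, an FA2 +inf leaks into `max_lse`, and `exp(inf - inf)` produces NaN, poisoning the merged output — which is consistent with the "bad results" this commit fixes.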
