File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -54,6 +54,15 @@ def merge_attn_states_kernel(
54
54
55
55
p_lse = tl .load (prefix_lse + head_idx * num_tokens + token_idx )
56
56
s_lse = tl .load (suffix_lse + head_idx * num_tokens + token_idx )
57
+
58
+ # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
59
+ # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
60
+ # If we see an inf assume FA2 and convert inf to -inf for consistency
61
+ # and correctness. Inf generally doesn't make sense in this context outside
62
+ # of undefined-behavior/FA2-case, so I think this a safe assumption.
63
+ p_lse = float ('-inf' ) if p_lse == float ('inf' ) else p_lse
64
+ s_lse = float ('-inf' ) if s_lse == float ('inf' ) else s_lse
65
+
57
66
max_lse = tl .maximum (p_lse , s_lse )
58
67
p_lse = p_lse - max_lse
59
68
s_lse = s_lse - max_lse
You can’t perform that action at this time.
0 commit comments