Skip to content

fix cudnn sdpa invalid seqlen for unused segments #30023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions jax/_src/cudnn/fused_attention_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,17 +525,13 @@ def _cu_offset(offsets, max_seq):
B, T, N, H = query.shape
_, S, _, _ = key.shape

q_seqlen = _shift_to_left(q_seqlen, -1)
kv_seqlen = _shift_to_left(kv_seqlen, -1)
q_seqlen = _shift_to_left(q_seqlen, 0)
kv_seqlen = _shift_to_left(kv_seqlen, 0)

q_offsets = _cu_offset(q_offsets, T)
kv_offsets = _cu_offset(kv_offsets, S)
q_offsets = _shift_to_left(q_offsets, -1)
kv_offsets = _shift_to_left(kv_offsets, -1)

# mark any invalid entries as maximum offset
q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)
kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets)
q_offsets = _shift_to_left(q_offsets, B * T)
kv_offsets = _shift_to_left(kv_offsets, B * S)

# multiply by stride_per_token to get correct offsets
# do it here because real stride changes after sharding
Expand Down
Loading