Skip to content

Commit 8a1dbef

Browse files
Merge pull request #30023 from Cjkkkk:fix_packed_layout_invalid_entries
PiperOrigin-RevId: 781800700
2 parents c9df14e + 32ddc91 commit 8a1dbef

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -525,17 +525,13 @@ def _cu_offset(offsets, max_seq):
525525
B, T, N, H = query.shape
526526
_, S, _, _ = key.shape
527527

528-
q_seqlen = _shift_to_left(q_seqlen, -1)
529-
kv_seqlen = _shift_to_left(kv_seqlen, -1)
528+
q_seqlen = _shift_to_left(q_seqlen, 0)
529+
kv_seqlen = _shift_to_left(kv_seqlen, 0)
530530

531531
q_offsets = _cu_offset(q_offsets, T)
532532
kv_offsets = _cu_offset(kv_offsets, S)
533-
q_offsets = _shift_to_left(q_offsets, -1)
534-
kv_offsets = _shift_to_left(kv_offsets, -1)
535-
536-
# mark any invalid entries as maximum offset
537-
q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)
538-
kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets)
533+
q_offsets = _shift_to_left(q_offsets, B * T)
534+
kv_offsets = _shift_to_left(kv_offsets, B * S)
539535

540536
# multiply by stride_per_token to get correct offsets
541537
# do it here because real stride changes after sharding

0 commit comments

Comments
 (0)