Commit 10dbfcf

voutcn authored and jax authors committed
Fix incorrect sequence length in batch megacore mode and enable megacore tests which were incorrectly disabled before.
Also configure sequence lengths in the unit test to cover edge cases (zero length, divisible/non-divisible by block size).

PiperOrigin-RevId: 623657472
1 parent: 4d9efff

3 files changed: +18 additions, -9 deletions

jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py

Lines changed: 9 additions & 2 deletions

@@ -281,7 +281,14 @@ def body(i, _):
     return ()
 
   bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
-  lax.fori_loop(0, lax.div(lengths_ref[b] + bk - 1, bk), body, ())
+
+  if megacore_mode == "batch":
+    num_cores = pl.num_programs(0)
+    length = lengths_ref[b * num_cores + core_index]
+  else:
+    length = lengths_ref[b]
+
+  lax.fori_loop(0, lax.div(length + bk - 1, bk), body, ())
 
 
 @functools.partial(
@@ -304,7 +311,7 @@ def paged_attention(
     pages_per_compute_block: int,
     megacore_mode: Optional[str] = None,
     inline_seq_dim: bool = True,
-) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
+) -> jax.Array:
   """Paged grouped query attention.
 
   Args:
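
The substance of the fix is in the first hunk: in "batch" megacore mode the batch is split across TensorCores, with grid dimension 0 enumerating the cores (hence pl.num_programs(0)), so the per-sequence length must be read at the flat batch index b * num_cores + core_index; indexing with b alone made every core read only the leading entries of lengths_ref. The second hunk corrects the declared return type to the single jax.Array the function returns. Below is a minimal NumPy sketch of the index mapping, not the Pallas kernel itself; num_cores = 2 and the lengths array are illustrative stand-ins:

import numpy as np

# Illustrative stand-ins: 6 sequences split across 2 TensorCores in
# "batch" megacore mode (same lengths as the updated unit test).
lengths = np.asarray([0, 3, 256, 513, 1023, 2048])
num_cores = 2
batch_per_core = len(lengths) // num_cores

for core_index in range(num_cores):
  for b in range(batch_per_core):
    # Buggy read: lengths[b] gives every core the lengths of the first
    # batch_per_core sequences only.
    # Fixed read: each (b, core_index) pair hits a unique flat index.
    length = lengths[b * num_cores + core_index]
    print(f"core={core_index} b={b} -> flat={b * num_cores + core_index}, length={length}")

# Sanity check: the fixed mapping covers every sequence exactly once.
flat = sorted(
    b * num_cores + core_index
    for core_index in range(num_cores)
    for b in range(batch_per_core)
)
assert flat == list(range(len(lengths)))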

tests/pallas/BUILD

Lines changed: 2 additions & 0 deletions

@@ -217,6 +217,8 @@ jax_test(
     shard_count = 2,
     tags = [
         "noasan",  # Times out.
+        "nomsan",  # Times out.
+        "notsan",  # Times out.
     ],
     deps = [
         "//jax:pallas_tpu_ops",

tests/pallas/paged_attention_kernel_test.py

Lines changed: 7 additions & 7 deletions

@@ -83,7 +83,7 @@ def _grouped_query_attention_reference(q, k, v, lengths):
 
 
 def _megacore_enabled():
-  return jax.devices()[0].device_kind == "TPU V4" and jtu.is_device_tpu(
+  return jax.devices()[0].device_kind == "TPU v4" or jtu.is_device_tpu(
       version=5, variant="p"
   )
 
@@ -114,15 +114,12 @@ def test_paged_attention(
     if not jtu.is_device_tpu_at_least(4):
       self.skipTest("Only supports TPU generation 4 or above")
     if megacore_mode and not _megacore_enabled():
-      self.skipTest("Megacore is only available on TPU v4 and TPU v5p")
+      self.skipTest("Megacore is only available on TPU v4 or TPU v5p")
     if num_kv_heads % 2 != 0 and megacore_mode == "kv_head":
      self.skipTest("Skip kv_head megacore mode when num_kv_heads is odd")
-    batch_size = 4
     max_kv_len = 2048
     block_size = 512
-    seq_lens = np.asarray(
-        [max_kv_len // batch_size * (i + 1) for i in range(batch_size)]
-    )
+    seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048])
     q, k_pages, v_pages, page_indices = _generate_qkv(
         seq_lens,
         page_size,
@@ -151,7 +148,10 @@ def test_paged_attention(
     else:
       atol, rtol = 1e-1, 1e-1
     np.testing.assert_allclose(
-        o.astype(jnp.float32), o_ref.astype(jnp.float32), atol=atol, rtol=rtol
+        o[np.where(seq_lens > 0)].astype(jnp.float32),
+        o_ref[np.where(seq_lens > 0)].astype(jnp.float32),
+        atol=atol,
+        rtol=rtol,
     )
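
Two details worth noting in the test changes: _megacore_enabled now matches the lowercase device kind "TPU v4" and correctly ORs the v4 and v5p checks, and the new seq_lens exercise the edge cases named in the commit message (zero length, lengths divisible and not divisible by the block size). Because the kernel's trip count is the ceiling division lax.div(length + bk - 1, bk), a zero-length sequence runs zero compute blocks and its output row is never written, which is why the assertion now keeps only rows with seq_lens > 0. A small sketch of the arithmetic, with bk = 512 as an assumed illustrative block size matching the test's block_size:

import numpy as np

seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048])
bk = 512  # assumed compute-block size in tokens, for illustration

# Ceiling division, mirroring lax.div(length + bk - 1, bk) in the kernel.
num_blocks = (seq_lens + bk - 1) // bk
print(num_blocks)  # [0 1 1 2 2 4]

# The zero-length sequence runs zero iterations, so its output row holds
# uninitialized data; compare only rows with a positive length, as the
# updated assert_allclose call does.
valid = np.where(seq_lens > 0)
print(seq_lens[valid])  # [   3  256  513 1023 2048]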
