
Commit 366f248

Set scoped vmem for paged attention (#8988)

Parent: 083934f

2 files changed: 9 additions, 1 deletion

test/test_ragged_paged_attention_kernel.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -254,6 +254,9 @@ def test_paged_attention_varlen_comprehensive(
       num_pages: int,
       num_queries_per_block: int,
   ):
+    if jtu.is_device_tpu(version=5, variant="e"):
+      self.skipTest(
+          "TPU v5e has small VMEM. It will run into VMEM OOM. Skip the test.")
     if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
       self.skipTest(
           "TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
@@ -285,6 +288,9 @@ def test_paged_attention_varlen_with_padding_comprehensive(
       num_pages: int,
       num_queries_per_block: int,
   ):
+    if jtu.is_device_tpu(version=5, variant="e"):
+      self.skipTest(
+          "TPU v5e has small VMEM. It will run into VMEM OOM. Skip the test.")
     if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
       self.skipTest(
           "TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
```

torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -961,7 +961,9 @@ def next_kv_blk_page_indices_index_map(kv_head_idx, logical_q_blk_idx,
               "arbitrary",
               "arbitrary",
               "arbitrary",
-          )),
+          ),
+          vmem_limit_bytes=64 * 1024 * 1024,
+      ),
       out_shape=out_shape,
   )
   buffer_index = jnp.zeros((1,), jnp.int32)
```
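
The kernel-side change re-balances the closing parentheses so that `vmem_limit_bytes` sits alongside `dimension_semantics` in the Pallas compiler parameters, capping the kernel's scoped VMEM at 64 MiB (64 * 1024 * 1024 bytes). Below is a minimal runnable sketch of that wiring, assuming recent JAX Pallas TPU APIs (`pltpu.TPUCompilerParams`, whose name may differ across JAX versions); the trivial copy kernel is a placeholder, not the ragged paged attention kernel.

```python
# A minimal sketch, not the commit's kernel: a trivial copy kernel whose
# compiler_params mirror the diff above. TPUCompilerParams and the BlockSpec
# argument order follow recent JAX releases and are assumptions here.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def copy_kernel(x_ref, o_ref):
  # Whole-block copy; the real kernel computes ragged paged attention.
  o_ref[...] = x_ref[...]


@jax.jit
def run(x):
  return pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(1,),
      in_specs=[pl.BlockSpec(x.shape, lambda i: (0, 0))],
      out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)),
      compiler_params=pltpu.TPUCompilerParams(
          # One semantics entry per grid dimension, as in the diff.
          dimension_semantics=("arbitrary",),
          # The commit's addition: cap scoped VMEM at 64 MiB so large
          # configurations stop hitting VMEM OOM.
          vmem_limit_bytes=64 * 1024 * 1024,
      ),
  )(x)


if __name__ == "__main__":
  x = jnp.ones((256, 256), jnp.float32)
  print(run(x).sum())  # requires a TPU backend
```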
