Skip to content

Commit e5a16a0

Browse files
apaszkejax authors
authored andcommitted
Skip offloading tests when used together with an old jaxlib
PiperOrigin-RevId: 617206370
1 parent 3830b17 commit e5a16a0

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/memories_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,8 @@ def g(ys, _):
12731273
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
12741274

12751275
def test_remat_checkpoint_dots_with_no_batch_dims(self):
1276+
if not jtu.test_device_matches(["tpu"]) and xla_extension_version < 247:
1277+
self.skipTest("Test requires a newer jaxlib")
12761278
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
12771279
"device", "pinned_host")
12781280

0 commit comments

Comments
 (0)