Skip to content

Commit 3dbbfef

Browse files
Jieying Luojax authors
authored andcommitted
[PJRT C API] Add a helper method to check whether the backend is cloud TPU built after certain date.
Skip tests that are not intended to work with older version libtpu. PiperOrigin-RevId: 610892754
1 parent fdbee31 commit 3dbbfef

File tree

4 files changed

+36
-0
lines changed

4 files changed

+36
-0
lines changed

jax/_src/test_util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from collections.abc import Generator, Iterable, Sequence
1717
from contextlib import contextmanager, ExitStack
18+
import datetime
1819
import inspect
1920
import io
2021
import functools
@@ -366,6 +367,24 @@ def is_device_cuda():
366367
def is_cloud_tpu():
367368
return running_in_cloud_tpu_vm
368369

370+
# Returns True if it is not cloud TPU. If it is cloud TPU, returns True if it is
371+
# built at least `date``.
372+
# TODO(b/327203806): after libtpu adds a XLA version and the oldest support
373+
# libtpu contains the XLA version, remove using built time to skip tests.
374+
def if_cloud_tpu_at_least(date: datetime.date):
375+
if not is_cloud_tpu():
376+
return True
377+
# The format of Cloud TPU platform_version is like:
378+
# PJRT C API
379+
# TFRT TPU v2
380+
# Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
381+
platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1]
382+
results = re.findall(r'\(.*?\)', platform_version)
383+
if len(results) != 1:
384+
return True
385+
build_date = date.fromtimestamp(int(results[0][1:-1]))
386+
return build_date >= date
387+
369388
def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
370389
pjrt_c_api_versions = xla_bridge.backend_pjrt_c_api_version()
371390
if pjrt_c_api_versions is None:

tests/memories_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import datetime
1516
import functools
1617
import math
1718
from absl.testing import absltest
@@ -1079,6 +1080,8 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
10791080
def setUp(self):
10801081
if not jtu.test_device_matches(["tpu"]):
10811082
self.skipTest("Memories do not work on CPU and GPU backends yet.")
1083+
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
1084+
self.skipTest("Memories do not work on Cloud TPU older than 2024/02/23.")
10821085
super().setUp()
10831086

10841087
def test_remat_jaxpr_offloadable(self):

tests/pallas/pallas_call_tpu_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Test TPU-specific extensions to pallas_call."""
1616

17+
import datetime
1718
import functools
1819
from absl.testing import absltest
1920
from absl.testing import parameterized
@@ -48,6 +49,8 @@ def setUp(self):
4849
super().setUp()
4950
if not self.interpret and jtu.device_under_test() != 'tpu':
5051
self.skipTest('Only interpret mode supported on non-TPU')
52+
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 10)):
53+
self.skipTest('Does not work on Cloud TPU older than 2024/02/10.')
5154

5255
def pallas_call(self, *args, **kwargs):
5356
return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
@@ -343,6 +346,8 @@ def dynamic_kernel(steps):
343346

344347
# TODO(apaszke): Add tests for scalar_prefetch too
345348
def test_dynamic_grid_scalar_input(self):
349+
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 14)):
350+
self.skipTest('Does not work on Cloud TPU older than 2024/02/14.')
346351
shape = (8, 128)
347352
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
348353

@@ -436,6 +441,9 @@ def dynamic_kernel(x, steps):
436441
)
437442

438443
def test_num_programs(self):
444+
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
445+
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
446+
439447
def kernel(y_ref):
440448
y_ref[0, 0] = pl.num_programs(0)
441449

@@ -451,6 +459,9 @@ def dynamic_kernel(steps):
451459
self.assertEqual(dynamic_kernel(4), 8)
452460

453461
def test_num_programs_block_spec(self):
462+
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
463+
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
464+
454465
def kernel(x_ref, y_ref):
455466
y_ref[...] = x_ref[...]
456467

tests/shard_alike_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import datetime
1516
import os
1617

1718
import jax
@@ -66,6 +67,8 @@ def setUp(self):
6667
super().setUp()
6768
if xla_extension_version < 227:
6869
self.skipTest('Requires xla_extension_version >= 227')
70+
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
71+
self.skipTest("Requires Cloud TPU older than 2024/02/23.")
6972

7073
def test_basic(self):
7174
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))

0 commit comments

Comments
 (0)