Skip to content

Commit 1b1c6e7

Browse files
Jieying Luojax authors
authored andcommitted
Enable some more C API tests.
PiperOrigin-RevId: 627065492
1 parent 667a0c1 commit 1b1c6e7

File tree

2 files changed

+0
-8
lines changed

2 files changed

+0
-8
lines changed

tests/pytorch_interoperability_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ def testJaxToTorch(self, shape, dtype):
8888

8989
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
9090
def testJaxArrayToTorch(self, shape, dtype):
91-
if xla_bridge.using_pjrt_c_api():
92-
self.skipTest("DLPack support is incomplete in the PJRT C API")
9391
if not config.enable_x64.value and dtype in [
9492
jnp.int64,
9593
jnp.float64,
@@ -111,8 +109,6 @@ def testJaxArrayToTorch(self, shape, dtype):
111109
self.assertAllClose(np, y.cpu().numpy())
112110

113111
def testTorchToJaxInt64(self):
114-
if xla_bridge.using_pjrt_c_api():
115-
self.skipTest("DLPack support is incomplete in the PJRT C API")
116112
# See https://github.com/google/jax/issues/11895
117113
x = jax.dlpack.from_dlpack(
118114
torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
@@ -121,8 +117,6 @@ def testTorchToJaxInt64(self):
121117

122118
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
123119
def testTorchToJax(self, shape, dtype):
124-
if xla_bridge.using_pjrt_c_api():
125-
self.skipTest("DLPack support is incomplete in the PJRT C API")
126120
if not config.enable_x64.value and dtype in [
127121
jnp.int64,
128122
jnp.float64,

tests/shard_map_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,8 +1991,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
19911991
def skip_if_custom_partitioning_not_supported(self):
19921992
if jtu.is_cloud_tpu():
19931993
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
1994-
if xla_bridge.using_pjrt_c_api():
1995-
raise unittest.SkipTest('custom partitioning not implemented in PJRT C API')
19961994

19971995
def test_custom_partitioning(self):
19981996
self.skip_if_custom_partitioning_not_supported()

0 commit comments

Comments
 (0)