Skip to content

Commit f88139b

Browse files
Jieying Luojax authors
authored andcommitted
Add a fallback when GetDefaultLayout is unimplemented for that backend.
PiperOrigin-RevId: 622278710
1 parent 7413894 commit f88139b

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

tests/array_interoperability_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ def setUp(self):
7676
gpu=[False, True],
7777
)
7878
def testJaxRoundTrip(self, shape, dtype, gpu):
79-
if xb.using_pjrt_c_api():
80-
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
8179
rng = jtu.rand_default(self.rng())
8280
np = rng(shape, dtype)
8381
if gpu and jtu.test_device_matches(["cpu"]):
@@ -119,8 +117,6 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu):
119117
)
120118
@unittest.skipIf(not tf, "Test requires TensorFlow")
121119
def testTensorFlowToJax(self, shape, dtype):
122-
if xb.using_pjrt_c_api():
123-
self.skipTest("DLPack support is incomplete in the PJRT C API")
124120
if (not config.enable_x64.value and
125121
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
126122
raise self.skipTest("x64 types are disabled by jax_enable_x64")
@@ -163,8 +159,6 @@ def testJaxToTensorFlow(self, shape, dtype):
163159

164160
@unittest.skipIf(not tf, "Test requires TensorFlow")
165161
def testTensorFlowToJaxInt64(self):
166-
if xb.using_pjrt_c_api():
167-
self.skipTest("DLPack support is incomplete in the PJRT C API")
168162
# See https://github.com/google/jax/issues/11895
169163
x = jax.dlpack.from_dlpack(
170164
tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64)))

0 commit comments

Comments
 (0)