Skip to content

Commit b2375fa

Browse files
Jieying Luojax authors
authored andcommitted
[PJRT C API] Add stream extension to support DLPack and implement this extension in CUDA plugin.
PiperOrigin-RevId: 626408630
1 parent cea36a0 commit b2375fa

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tests/array_interoperability_test.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,11 @@ def setUp(self):
7373
@jtu.sample_product(
7474
shape=all_shapes,
7575
dtype=dlpack_dtypes,
76-
copy=[False, True, None]
76+
copy=[False, True, None],
77+
use_stream=[False, True],
7778
)
7879
@jtu.run_on_devices("gpu")
79-
def testJaxRoundTrip(self, shape, dtype, copy):
80-
if xb.using_pjrt_c_api():
81-
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
80+
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
8281
rng = jtu.rand_default(self.rng())
8382
np = rng(shape, dtype)
8483

@@ -91,7 +90,11 @@ def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
9190
device = jax.devices("gpu")[0]
9291
y = jax.device_put(x, device)
9392
dl_device = y.__dlpack_device__()
94-
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
93+
if use_stream:
94+
stream = tuple(y.devices())[0].get_stream_for_external_ready_events()
95+
dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream)
96+
else:
97+
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
9598
z = jax.dlpack.from_dlpack(dlpack)
9699

97100
self.assertEqual(z.devices(), {device})

0 commit comments

Comments
 (0)