@@ -73,12 +73,11 @@ def setUp(self):
73
73
@jtu .sample_product (
74
74
shape = all_shapes ,
75
75
dtype = dlpack_dtypes ,
76
- copy = [False , True , None ]
76
+ copy = [False , True , None ],
77
+ use_stream = [False , True ],
77
78
)
78
79
@jtu .run_on_devices ("gpu" )
79
- def testJaxRoundTrip (self , shape , dtype , copy ):
80
- if xb .using_pjrt_c_api ():
81
- self .skipTest ("DLPack support is incomplete in the PJRT C API" ) # TODO(skyewm)
80
+ def testJaxRoundTrip (self , shape , dtype , copy , use_stream ):
82
81
rng = jtu .rand_default (self .rng ())
83
82
np = rng (shape , dtype )
84
83
@@ -91,7 +90,11 @@ def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
91
90
device = jax .devices ("gpu" )[0 ]
92
91
y = jax .device_put (x , device )
93
92
dl_device = y .__dlpack_device__ ()
94
- dlpack = jax .dlpack .to_dlpack (y , copy = copy )
93
+ if use_stream :
94
+ stream = tuple (y .devices ())[0 ].get_stream_for_external_ready_events ()
95
+ dlpack = jax .dlpack .to_dlpack (y , copy = copy , stream = stream )
96
+ else :
97
+ dlpack = jax .dlpack .to_dlpack (y , copy = copy )
95
98
z = jax .dlpack .from_dlpack (dlpack )
96
99
97
100
self .assertEqual (z .devices (), {device })
0 commit comments