@@ -76,8 +76,6 @@ def setUp(self):
76
76
gpu = [False , True ],
77
77
)
78
78
def testJaxRoundTrip (self , shape , dtype , gpu ):
79
- if xb .using_pjrt_c_api ():
80
- self .skipTest ("DLPack support is incomplete in the PJRT C API" ) # TODO(skyewm)
81
79
rng = jtu .rand_default (self .rng ())
82
80
np = rng (shape , dtype )
83
81
if gpu and jtu .test_device_matches (["cpu" ]):
@@ -119,8 +117,6 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu):
119
117
)
120
118
@unittest .skipIf (not tf , "Test requires TensorFlow" )
121
119
def testTensorFlowToJax (self , shape , dtype ):
122
- if xb .using_pjrt_c_api ():
123
- self .skipTest ("DLPack support is incomplete in the PJRT C API" )
124
120
if (not config .enable_x64 .value and
125
121
dtype in [jnp .int64 , jnp .uint64 , jnp .float64 ]):
126
122
raise self .skipTest ("x64 types are disabled by jax_enable_x64" )
@@ -163,8 +159,6 @@ def testJaxToTensorFlow(self, shape, dtype):
163
159
164
160
@unittest .skipIf (not tf , "Test requires TensorFlow" )
165
161
def testTensorFlowToJaxInt64 (self ):
166
- if xb .using_pjrt_c_api ():
167
- self .skipTest ("DLPack support is incomplete in the PJRT C API" )
168
162
# See https://github.com/google/jax/issues/11895
169
163
x = jax .dlpack .from_dlpack (
170
164
tf .experimental .dlpack .to_dlpack (tf .ones ((2 , 3 ), tf .int64 )))
0 commit comments