@@ -88,8 +88,6 @@ def testJaxToTorch(self, shape, dtype):
 
   @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
   def testJaxArrayToTorch(self, shape, dtype):
-    if xla_bridge.using_pjrt_c_api():
-      self.skipTest("DLPack support is incomplete in the PJRT C API")
     if not config.enable_x64.value and dtype in [
         jnp.int64,
         jnp.float64,
@@ -111,8 +109,6 @@ def testJaxArrayToTorch(self, shape, dtype):
     self.assertAllClose(np, y.cpu().numpy())
 
   def testTorchToJaxInt64(self):
-    if xla_bridge.using_pjrt_c_api():
-      self.skipTest("DLPack support is incomplete in the PJRT C API")
     # See https://github.com/google/jax/issues/11895
     x = jax.dlpack.from_dlpack(
         torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
@@ -121,8 +117,6 @@ def testTorchToJaxInt64(self):
 
   @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
   def testTorchToJax(self, shape, dtype):
-    if xla_bridge.using_pjrt_c_api():
-      self.skipTest("DLPack support is incomplete in the PJRT C API")
     if not config.enable_x64.value and dtype in [
         jnp.int64,
         jnp.float64,
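
For context, the skips removed above guarded the JAX/PyTorch DLPack round trip. A minimal sketch of that round trip, not part of the change itself, assuming jax and torch builds where `jax.dlpack` and `torch.utils.dlpack` are available on the current backend:

```python
import jax
import jax.numpy as jnp
import torch
import torch.utils.dlpack

# Torch -> JAX: wrap a torch tensor in a DLPack capsule and import it.
t = torch.ones((2, 3), dtype=torch.float32)
x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

# JAX -> Torch: export a JAX array and rebuild a torch tensor from it.
y = jnp.arange(6.0).reshape(2, 3)
t2 = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(y))

print(x.dtype, x.shape)    # float32 (2, 3)
print(t2.dtype, t2.shape)  # torch.float32 torch.Size([2, 3])
```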