Skip to content

Commit f569031

Browse files
junwhanahnjax authors
authored andcommitted
Reverts 55394a0
PiperOrigin-RevId: 616201321
1 parent 94122f8 commit f569031

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

jax/_src/api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2958,7 +2958,29 @@ def try_to_block(x):
29582958
return x.block_until_ready()
29592959
except AttributeError:
29602960
return x
2961-
return tree_map(try_to_block, x)
2961+
2962+
if xla_extension_version < 246:
2963+
return tree_map(try_to_block, x)
2964+
2965+
arrays = []
2966+
for leaf in tree_leaves(x):
2967+
if isinstance(leaf, array.ArrayImpl):
2968+
arrays.append(leaf)
2969+
else:
2970+
try_to_block(leaf)
2971+
2972+
if not arrays:
2973+
# `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
2974+
# jax.Array.
2975+
pass
2976+
elif len(arrays) == 1:
2977+
# Fast path for single array.
2978+
try_to_block(arrays[0])
2979+
else:
2980+
# Optimized for multiple arrays.
2981+
xc.batched_block_until_ready(arrays)
2982+
2983+
return x
29622984

29632985

29642986
def clear_backends():

tests/api_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2374,6 +2374,20 @@ def test_block_until_ready_function(self):
23742374
self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False)
23752375
self.assertAllClose(pytree[1], np.ones(3), check_dtypes=False)
23762376

2377+
def test_block_until_ready_numpy_arrays(self):
2378+
pytree = (np.ones(1), np.ones(2))
2379+
pytree = jax.block_until_ready(pytree)
2380+
self.assertAllClose(pytree[0], np.ones(1), check_dtypes=False)
2381+
self.assertAllClose(pytree[1], np.ones(2), check_dtypes=False)
2382+
2383+
def test_block_until_ready_mixed(self):
2384+
pytree = (device_put(1.), device_put(2.), np.ones(3), 4)
2385+
pytree = jax.block_until_ready(pytree)
2386+
self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False)
2387+
self.assertAllClose(pytree[1], jnp.array(2.), check_dtypes=False)
2388+
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
2389+
self.assertEqual(pytree[3], 4)
2390+
23772391
def test_devicearray_weakref_friendly(self):
23782392
x = device_put(1.)
23792393
y = weakref.ref(x)

0 commit comments

Comments
 (0)