File tree Expand file tree Collapse file tree 2 files changed +37
-1
lines changed Expand file tree Collapse file tree 2 files changed +37
-1
lines changed Original file line number Diff line number Diff line change @@ -2958,7 +2958,29 @@ def try_to_block(x):
2958
2958
return x .block_until_ready ()
2959
2959
except AttributeError :
2960
2960
return x
2961
- return tree_map (try_to_block , x )
2961
+
2962
+ if xla_extension_version < 246 :
2963
+ return tree_map (try_to_block , x )
2964
+
2965
+ arrays = []
2966
+ for leaf in tree_leaves (x ):
2967
+ if isinstance (leaf , array .ArrayImpl ):
2968
+ arrays .append (leaf )
2969
+ else :
2970
+ try_to_block (leaf )
2971
+
2972
+ if not arrays :
2973
+ # `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
2974
+ # jax.Array.
2975
+ pass
2976
+ elif len (arrays ) == 1 :
2977
+ # Fast path for single array.
2978
+ try_to_block (arrays [0 ])
2979
+ else :
2980
+ # Optimized for multiple arrays.
2981
+ xc .batched_block_until_ready (arrays )
2982
+
2983
+ return x
2962
2984
2963
2985
2964
2986
def clear_backends ():
Original file line number Diff line number Diff line change @@ -2374,6 +2374,20 @@ def test_block_until_ready_function(self):
2374
2374
self .assertAllClose (pytree [0 ], jnp .array (1. ), check_dtypes = False )
2375
2375
self .assertAllClose (pytree [1 ], np .ones (3 ), check_dtypes = False )
2376
2376
2377
+ def test_block_until_ready_numpy_arrays (self ):
2378
+ pytree = (np .ones (1 ), np .ones (2 ))
2379
+ pytree = jax .block_until_ready (pytree )
2380
+ self .assertAllClose (pytree [0 ], np .ones (1 ), check_dtypes = False )
2381
+ self .assertAllClose (pytree [1 ], np .ones (2 ), check_dtypes = False )
2382
+
2383
+ def test_block_until_ready_mixed (self ):
2384
+ pytree = (device_put (1. ), device_put (2. ), np .ones (3 ), 4 )
2385
+ pytree = jax .block_until_ready (pytree )
2386
+ self .assertAllClose (pytree [0 ], jnp .array (1. ), check_dtypes = False )
2387
+ self .assertAllClose (pytree [1 ], jnp .array (2. ), check_dtypes = False )
2388
+ self .assertAllClose (pytree [2 ], np .ones (3 ), check_dtypes = False )
2389
+ self .assertEqual (pytree [3 ], 4 )
2390
+
2377
2391
def test_devicearray_weakref_friendly (self ):
2378
2392
x = device_put (1. )
2379
2393
y = weakref .ref (x )
You can’t perform that action at this time.
0 commit comments