Skip to content

Commit 5981df7

Browse files
superbobryjax authors
authored andcommitted
Removed unnecessary jax.tree.map calls from *_callback_impl functions
jax.device_put works for any PyTree. PiperOrigin-RevId: 626510762
1 parent 52f5f70 commit 5981df7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

jax/_src/callback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def pure_callback_impl(
7373
):
7474
del sharding, vectorized, result_avals
7575
cpu_device, *_ = jax.local_devices(backend="cpu")
76-
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
76+
args = jax.device_put(args, cpu_device)
7777
with jax.default_device(cpu_device):
7878
try:
7979
return tree_util.tree_map(np.asarray, callback(*args))
@@ -401,7 +401,7 @@ def io_callback_impl(
401401
):
402402
del result_avals, sharding, ordered
403403
cpu_device, *_ = jax.local_devices(backend="cpu")
404-
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
404+
args = jax.device_put(args, cpu_device)
405405
with jax.default_device(cpu_device):
406406
try:
407407
return tree_util.tree_map(np.asarray, callback(*args))

0 commit comments

Comments
 (0)