File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -73,7 +73,7 @@ def pure_callback_impl(
73
73
):
74
74
del sharding , vectorized , result_avals
75
75
cpu_device , * _ = jax .local_devices (backend = "cpu" )
76
- args = tree_util . tree_map ( lambda arg : jax .device_put (arg , cpu_device ), args )
76
+ args = jax .device_put (args , cpu_device )
77
77
with jax .default_device (cpu_device ):
78
78
try :
79
79
return tree_util .tree_map (np .asarray , callback (* args ))
@@ -401,7 +401,7 @@ def io_callback_impl(
401
401
):
402
402
del result_avals , sharding , ordered
403
403
cpu_device , * _ = jax .local_devices (backend = "cpu" )
404
- args = tree_util . tree_map ( lambda arg : jax .device_put (arg , cpu_device ), args )
404
+ args = jax .device_put (args , cpu_device )
405
405
with jax .default_device (cpu_device ):
406
406
try :
407
407
return tree_util .tree_map (np .asarray , callback (* args ))
You can’t perform that action at this time.
0 commit comments