@@ -2439,7 +2439,7 @@ def make_jaxpr_f(*args, **kwargs):
2439
2439
2440
2440
def _infer_src_sharding (src , x ) -> Sharding | None :
2441
2441
if src is not None :
2442
- return src
2442
+ return src # type: ignore
2443
2443
if isinstance (x , array .ArrayImpl ):
2444
2444
return x .sharding
2445
2445
elif isinstance (x , core .Tracer ):
@@ -2493,21 +2493,20 @@ def device_put(
2493
2493
isinstance (device , (xc .Device , Sharding , TransferToMemoryKind ))) and
2494
2494
(src is None or
2495
2495
isinstance (src , (xc .Device , Sharding , TransferToMemoryKind )))):
2496
- for leaf in tree_leaves ( x ):
2497
- _check_sharding (shaped_abstractify (leaf ), s = device )
2498
- return tree_map (
2499
- lambda y : dispatch . device_put_p . bind (
2500
- y , device = device , src = _infer_src_sharding ( src , y )) , x )
2496
+ def _map ( y ):
2497
+ _check_sharding (shaped_abstractify (y ), s = device )
2498
+ return dispatch . device_put_p . bind (
2499
+ y , device = device , src = _infer_src_sharding ( src , y ))
2500
+ return tree_map ( _map , x )
2501
2501
2502
2502
x_flat , treedef = tree_flatten (x )
2503
2503
device_flat = flatten_axes ("device_put device" , treedef , device )
2504
2504
src_flat = flatten_axes ("device_put source" , treedef , src )
2505
- for x_leaf , device_leaf in zip (x_flat , device_flat ):
2506
- _check_sharding (shaped_abstractify (x_leaf ), device_leaf )
2507
- out_flat = [
2508
- dispatch .device_put_p .bind (xf , device = d , src = _infer_src_sharding (s , xf ))
2509
- for xf , d , s in zip (x_flat , device_flat , src_flat )
2510
- ]
2505
+ out_flat = []
2506
+ for xf , d , s in zip (x_flat , device_flat , src_flat ):
2507
+ _check_sharding (shaped_abstractify (xf ), d )
2508
+ out_flat .append (dispatch .device_put_p .bind (
2509
+ xf , device = d , src = _infer_src_sharding (s , xf )))
2511
2510
return tree_unflatten (treedef , out_flat )
2512
2511
2513
2512
0 commit comments