Skip to content

Commit 1837b43

Browse files
yashk2810jax authors
authored andcommitted
Merge some loops in device_put since it's trivial to do so
PiperOrigin-RevId: 626546322
1 parent 0943eb3 commit 1837b43

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

jax/_src/api.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,7 +2439,7 @@ def make_jaxpr_f(*args, **kwargs):
24392439

24402440
def _infer_src_sharding(src, x) -> Sharding | None:
24412441
if src is not None:
2442-
return src
2442+
return src # type: ignore
24432443
if isinstance(x, array.ArrayImpl):
24442444
return x.sharding
24452445
elif isinstance(x, core.Tracer):
@@ -2493,21 +2493,20 @@ def device_put(
24932493
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
24942494
(src is None or
24952495
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
2496-
for leaf in tree_leaves(x):
2497-
_check_sharding(shaped_abstractify(leaf), s=device)
2498-
return tree_map(
2499-
lambda y: dispatch.device_put_p.bind(
2500-
y, device=device, src=_infer_src_sharding(src, y)), x)
2496+
def _map(y):
2497+
_check_sharding(shaped_abstractify(y), s=device)
2498+
return dispatch.device_put_p.bind(
2499+
y, device=device, src=_infer_src_sharding(src, y))
2500+
return tree_map(_map, x)
25012501

25022502
x_flat, treedef = tree_flatten(x)
25032503
device_flat = flatten_axes("device_put device", treedef, device)
25042504
src_flat = flatten_axes("device_put source", treedef, src)
2505-
for x_leaf, device_leaf in zip(x_flat, device_flat):
2506-
_check_sharding(shaped_abstractify(x_leaf), device_leaf)
2507-
out_flat = [
2508-
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
2509-
for xf, d, s in zip(x_flat, device_flat, src_flat)
2510-
]
2505+
out_flat = []
2506+
for xf, d, s in zip(x_flat, device_flat, src_flat):
2507+
_check_sharding(shaped_abstractify(xf), d)
2508+
out_flat.append(dispatch.device_put_p.bind(
2509+
xf, device=d, src=_infer_src_sharding(s, xf)))
25112510
return tree_unflatten(treedef, out_flat)
25122511

25132512

0 commit comments

Comments
 (0)