Skip to content

Commit 837f0bb

Browse files
yashk2810jax authors
authored andcommitted
Cache the _check_sharding check in device_put. If aval and sharding are the same, no need to check multiple times
PiperOrigin-RevId: 626244240
1 parent 8fec8a6 commit 837f0bb

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

jax/_src/api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import collections
2626
from collections.abc import Generator, Hashable, Iterable, Sequence
27-
from functools import partial
27+
from functools import partial, lru_cache
2828
import inspect
2929
import math
3030
import typing
@@ -2451,9 +2451,9 @@ def _infer_src_sharding(src, x) -> Sharding | None:
24512451

24522452
# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
24532453
# to check if shardings are compatible with the input.
2454-
def _check_sharding(x, s):
2454+
@lru_cache(maxsize=2048)
2455+
def _check_sharding(aval, s):
24552456
if isinstance(s, Sharding):
2456-
aval = shaped_abstractify(x)
24572457
if isinstance(aval, core.AbstractToken):
24582458
aval = core.token_shaped_array
24592459
if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
@@ -2494,7 +2494,7 @@ def device_put(
24942494
(src is None or
24952495
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
24962496
for leaf in tree_leaves(x):
2497-
_check_sharding(leaf, s=device)
2497+
_check_sharding(shaped_abstractify(leaf), s=device)
24982498
return tree_map(
24992499
lambda y: dispatch.device_put_p.bind(
25002500
y, device=device, src=_infer_src_sharding(src, y)), x)
@@ -2503,7 +2503,7 @@ def device_put(
25032503
device_flat = flatten_axes("device_put device", treedef, device)
25042504
src_flat = flatten_axes("device_put source", treedef, src)
25052505
for x_leaf, device_leaf in zip(x_flat, device_flat):
2506-
_check_sharding(x_leaf, device_leaf)
2506+
_check_sharding(shaped_abstractify(x_leaf), device_leaf)
25072507
out_flat = [
25082508
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
25092509
for xf, d, s in zip(x_flat, device_flat, src_flat)

0 commit comments

Comments
 (0)