24
24
25
25
import collections
26
26
from collections .abc import Generator , Hashable , Iterable , Sequence
27
- from functools import partial
27
+ from functools import partial , lru_cache
28
28
import inspect
29
29
import math
30
30
import typing
@@ -2451,9 +2451,9 @@ def _infer_src_sharding(src, x) -> Sharding | None:
2451
2451
2452
2452
# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
2453
2453
# to check if shardings are compatible with the input.
2454
- def _check_sharding (x , s ):
2454
+ @lru_cache (maxsize = 2048 )
2455
+ def _check_sharding (aval , s ):
2455
2456
if isinstance (s , Sharding ):
2456
- aval = shaped_abstractify (x )
2457
2457
if isinstance (aval , core .AbstractToken ):
2458
2458
aval = core .token_shaped_array
2459
2459
if isinstance (s , XLACompatibleSharding ) and not isinstance (s , PmapSharding ):
@@ -2494,7 +2494,7 @@ def device_put(
2494
2494
(src is None or
2495
2495
isinstance (src , (xc .Device , Sharding , TransferToMemoryKind )))):
2496
2496
for leaf in tree_leaves (x ):
2497
- _check_sharding (leaf , s = device )
2497
+ _check_sharding (shaped_abstractify ( leaf ) , s = device )
2498
2498
return tree_map (
2499
2499
lambda y : dispatch .device_put_p .bind (
2500
2500
y , device = device , src = _infer_src_sharding (src , y )), x )
@@ -2503,7 +2503,7 @@ def device_put(
2503
2503
device_flat = flatten_axes ("device_put device" , treedef , device )
2504
2504
src_flat = flatten_axes ("device_put source" , treedef , src )
2505
2505
for x_leaf , device_leaf in zip (x_flat , device_flat ):
2506
- _check_sharding (x_leaf , device_leaf )
2506
+ _check_sharding (shaped_abstractify ( x_leaf ) , device_leaf )
2507
2507
out_flat = [
2508
2508
dispatch .device_put_p .bind (xf , device = d , src = _infer_src_sharding (s , xf ))
2509
2509
for xf , d , s in zip (x_flat , device_flat , src_flat )
0 commit comments