Skip to content

Commit 9e98932

Browse files
yashk2810jax authors
authored andcommitted
Make sure we don't return GSPMDSharding in compiled.input_shardings
PiperOrigin-RevId: 624343180
1 parent 0941560 commit 9e98932

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,23 +2574,23 @@ def _get_out_sharding_from_orig_sharding(
25742574
out.append(o)
25752575
return out
25762576

2577-
def maybe_get_orig_out_sharding(
2578-
in_shardings, out_shardings, in_avals, out_avals):
2579-
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in out_shardings):
2580-
return out_shardings
2577+
def maybe_recover_user_shardings(
2578+
old_shardings, new_shardings, old_avals, new_avals):
2579+
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
2580+
return new_shardings
25812581

25822582
orig_in_s = None
25832583
orig_aval = None
2584-
for oi, aval in safe_zip(in_shardings, in_avals):
2584+
for oi, aval in safe_zip(old_shardings, old_avals):
25852585
if type(oi) in _orig_out_sharding_handlers:
25862586
orig_in_s = oi
25872587
orig_aval = aval
25882588
break
25892589
if orig_in_s is not None:
25902590
return _get_out_sharding_from_orig_sharding(
2591-
out_shardings, out_avals, orig_in_s, orig_aval)
2591+
new_shardings, new_avals, orig_in_s, orig_aval)
25922592

2593-
return out_shardings
2593+
return new_shardings
25942594

25952595

25962596
def _get_layouts_from_executable(
@@ -2744,6 +2744,10 @@ def _maybe_get_and_check_in_shardings(
27442744
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
27452745
"(User sharding)")
27462746
new_in_shardings.append(orig)
2747+
2748+
new_in_shardings = maybe_recover_user_shardings(
2749+
in_shardings, new_in_shardings, global_in_avals, global_in_avals)
2750+
27472751
return new_in_shardings
27482752

27492753

@@ -2921,7 +2925,7 @@ def from_hlo(name: str,
29212925
assert all(i is None for i in in_layouts)
29222926
assert all(o is None for o in out_layouts)
29232927

2924-
out_shardings = maybe_get_orig_out_sharding(
2928+
out_shardings = maybe_recover_user_shardings(
29252929
in_shardings, out_shardings, global_in_avals, global_out_avals)
29262930

29272931
out_shardings = finalize_out_shardings(out_shardings, da)

tests/pjit_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3999,6 +3999,19 @@ def f(x, y, z, a, b):
39993999
self.assertArraysEqual(out4, np_inp * 3)
40004000
self.assertArraysEqual(out5, np_inp.T)
40014001

4002+
def test_input_shardings_aot(self):
4003+
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
4004+
np_inp = np.arange(16).reshape(8, 2)
4005+
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
4006+
4007+
@jax.jit
4008+
def f(x, y):
4009+
return x * 2, y.T
4010+
4011+
arg_shardings, _ = f.lower(arr, np_inp).compile().input_shardings
4012+
for s in arg_shardings:
4013+
self.assertIsInstance(s, NamedSharding)
4014+
40024015
def test_parameter_tupled_jit(self):
40034016
if not jtu.test_device_matches(["tpu"]):
40044017
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')

0 commit comments

Comments
 (0)