Skip to content

Commit ac41032

Browse files
author
jax authors
committed
[XLA:MHLO->HLO] Allow partially-set parameter tuple sharding to exist by filling in the missing sharding elements with replicated sharding. (This is what is done for the missing shardings in the result tuple.)
Before this change, if an element of a tuple parameter did not have a sharding, the MHLO->HLO conversion dropped the existing annotations on the parameter. This caused the parameter sharding to disappear for a model, which then resulted in an OOM. PiperOrigin-RevId: 615917181
1 parent 9a00721 commit ac41032

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,8 +2076,8 @@ def lower_sharding_computation(
20762076
any(not is_unspecified(o) for o in out_shardings))
20772077

20782078
gs = GSPMDSharding.get_replicated(device_assignment)
2079-
# if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
2080-
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
2079+
if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
2080+
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
20812081

20822082
da_object = _create_da_object(tuple(device_assignment))
20832083

tests/pjit_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,6 +3942,22 @@ def f(x, y, z, a, b):
39423942
self.assertArraysEqual(out4, np_inp * 3)
39433943
self.assertArraysEqual(out5, np_inp.T)
39443944

3945+
def test_parameter_tupled_jit(self):
3946+
if not jtu.test_device_matches(["tpu"]):
3947+
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
3948+
3949+
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
3950+
s = NamedSharding(mesh, P('x'))
3951+
3952+
@jax.jit
3953+
def f(*args):
3954+
return args * 2
3955+
3956+
inp = np.arange(8)
3957+
arr = jax.device_put(inp, s)
3958+
inps = [arr, *[inp] * 2001]
3959+
f(inps) # doesn't crash
3960+
39453961

39463962
class TempSharding(Sharding):
39473963

0 commit comments

Comments
 (0)