Commit 4a6ee78

Author: jax authors
[XLA] Clear derived instruction's sharding only if shapes are incompatible.
When AlgebraicSimplifier calls `dot->SetupDerivedInstruction(new_lhs);` in HandleDot, the lhs sharding was cleared whenever the dot had no sharding. With this CL, the lhs preserves its sharding: the condition for clearing the sharding is narrowed to the case where the shapes are incompatible.

Fixes #20710

PiperOrigin-RevId: 624731930
1 parent 7d0ba76 commit 4a6ee78
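The narrowed clearing condition can be sketched in plain Python. This is a hedged illustration of the behavior described in the commit message, not XLA's actual C++; `Instr`, `compatible`, and `setup_derived_instruction` are hypothetical names standing in for `HloInstruction` and `SetupDerivedInstruction`:

```python
class Instr:
    """Hypothetical stand-in for an HLO instruction: a shape plus an optional sharding."""
    def __init__(self, shape, sharding=None):
        self.shape = tuple(shape)
        self.sharding = sharding


def compatible(a, b):
    # Stand-in for XLA's shape-compatibility check.
    return a == b


def setup_derived_instruction(original, derived):
    """Propagate sharding from `original` to `derived`.

    Before this CL: `derived.sharding` was cleared whenever `original`
    had no sharding. After: it is cleared only when the shapes are
    incompatible, so e.g. the dot's new lhs keeps its sharding.
    """
    if original.sharding is not None:
        derived.sharding = original.sharding
    elif not compatible(original.shape, derived.shape):
        derived.sharding = None  # shapes differ, so the old sharding is meaningless


# The HandleDot case: the dot has no sharding, new_lhs does, shapes match.
dot = Instr(shape=(8, 3), sharding=None)
new_lhs = Instr(shape=(8, 3), sharding="devices(8,1)")
setup_derived_instruction(dot, new_lhs)
print(new_lhs.sharding)  # sharding preserved: devices(8,1)
```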

File tree

1 file changed: +27 −0 lines changed


tests/pjit_test.py

Lines changed: 27 additions & 0 deletions
@@ -4028,6 +4028,33 @@ def f(*args):
     inps = [arr, *[inp] * 2001]
     f(inps)  # doesn't crash
 
+  def test_spmd_preserves_input_sharding_vmap_grad(self):
+    # https://github.com/google/jax/issues/20710
+    n_devices = jax.device_count()
+    sharding = PositionalSharding(jax.devices())
+
+    def model(params, x):
+      return x @ params
+
+    feature_dim = 3
+    batch_size_total = 8
+
+    # Get example data
+    x = jnp.ones((batch_size_total, feature_dim))
+    params = jnp.ones(feature_dim)
+
+    # Shard data, replicate params
+    x = jax.device_put(x, sharding.reshape(n_devices, 1))
+    params = jax.device_put(params, sharding.replicate(axis=0))
+
+    model(params, x)  # doesn't crash
+
+    jax.vmap(model, in_axes=(None, 0))(params, x)  # doesn't crash
+
+    jax.grad(lambda p: model(p, x).sum())(params)  # doesn't crash
+
+    jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x)  # doesn't crash
+
 
 
 class TempSharding(Sharding):
