Skip to content

Commit f926b3c

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Allow dropping into explicit_axes with a reshard in an outer vmap with spmd_axis_name set.
PiperOrigin-RevId: 782149747
1 parent 72198ee commit f926b3c

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

jax/_src/pjit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2945,7 +2945,6 @@ def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
29452945
mlir.register_lowering(reshard_p, _reshard_hlo_lowering)
29462946

29472947
def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding):
2948-
assert axis_data.spmd_name is None
29492948
x, = vals_in
29502949
d, = dims_in
29512950
vmapped_dst_sharding = batching.get_sharding_for_vmap(

tests/pjit_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8431,6 +8431,27 @@ def f(x):
84318431
else:
84328432
self.assertEqual(lowered_text.count('unspecified_dims=[0,1]'), 3)
84338433

8434+
@jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2)
8435+
def test_vmap_spmd_axis_name_explicit_axes_inside(self, mesh):
8436+
np_inp = np.arange(16).reshape(2, 8)
8437+
arr1 = jax.device_put(np_inp, P())
8438+
arr2 = jax.device_put(np_inp, P())
8439+
8440+
@jax.jit
8441+
def f(x, y):
8442+
@partial(explicit_axes, in_sharding=(P('y'), P('y')))
8443+
def g(a, b):
8444+
self.assertEqual(a.aval.sharding.spec, P('y'))
8445+
self.assertEqual(b.aval.sharding.spec, P('y'))
8446+
a = reshard(a, P())
8447+
self.assertEqual(a.aval.sharding.spec, P(None))
8448+
out = a * b
8449+
self.assertEqual(out.aval.sharding.spec, P('y'))
8450+
return out
8451+
return g(x, y)
8452+
8453+
out = jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr1, arr2) # doesn't crash
8454+
84348455

84358456
@jtu.pytest_mark_if_available('multiaccelerator')
84368457
class PJitErrorTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)