Commit 7ba811e

pschuh and jax authors authored and committed

Support auto in shard_map.

- Pull mesh from NamedSharding when rewriting manual axes.
- Properly set manual axes in SPMDAxisContext in shard_map.
- Properly set dims as unspecified inside shard_map.

PiperOrigin-RevId: 627156892

1 parent a8ee946 · commit 7ba811e
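
Taken together, these changes let a `shard_map` body leave some mesh axes to the automatic (GSPMD) partitioner via the `auto` argument. Below is a minimal usage sketch modeled on the `test_partial_auto` case added in this commit; constructing the mesh directly from `jax.devices()` (instead of the test helper) is an assumption, and it needs at least four devices:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative 2x2 mesh (assumes >= 4 devices). Inside the shard_map
# body, axis 'i' is manual while axis 'j' is marked auto, i.e. left to
# the partitioner.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

def g(x):
  # The auto axis 'j' may still appear in sharding constraints here;
  # the partitioner decides how to honor it.
  x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(None, 'j')))
  return x * x

@jax.jit
def f(x):
  return shard_map(g, mesh,
                   in_specs=P('i', None),   # specs name only manual axes
                   out_specs=P('i', None),
                   check_rep=False,
                   auto=frozenset({'j'}))(x)

v = jax.device_put(jnp.arange(32.).reshape(4, 8),
                   NamedSharding(mesh, P('i', 'j')))
print(f(v))  # elementwise v * v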

File tree

4 files changed (+146, -15 lines)

jax/_src/pjit.py

Lines changed: 2 additions & 0 deletions

@@ -2339,6 +2339,8 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
   # NamedSharding. So update the NamedSharding to have the manual axes.
   if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
     mesh = resource_env.physical_mesh
+    if mesh.empty and isinstance(sharding, NamedSharding):
+      mesh = sharding.mesh
     parsed_pspec = parse_flatten_op_sharding(
         sharding._to_xla_hlo_sharding(aval.ndim), mesh)[0]
     sharding = NamedSharding._from_parsed_pspec(

jax/_src/sharding_impls.py

Lines changed: 17 additions & 5 deletions

@@ -136,10 +136,16 @@ def is_equivalent_to(self: XLACompatibleSharding,  # type: ignore


 @functools.lru_cache
-def _check_mesh_resource_axis(mesh, parsed_pspec):
+def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
   try:
-    [mesh.shape[r] for p in parsed_pspec if p is not None
-     for r in p]
+    for p in parsed_pspec:
+      if p is not None:
+        for r in p:
+          mesh.shape[r]
+          if r in _manual_axes:
+            raise ValueError(
+                f"Axis: {r} of {parsed_pspec.get_partition_spec()} "
+                f"is also found in manual_axes: {_manual_axes}.") from None
   except KeyError as e:
     raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
                      "undefined.") from None

@@ -184,6 +190,10 @@ def named_sharding_to_xla_hlo_sharding(
     axis_names = self.mesh.axis_names
     for manual_axis in self._manual_axes:
       special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
+      if xla_extension_version < 259:
+        if manual_axis in array_mapping:  # type: ignore
+          raise ValueError(f"manual axis {repr(manual_axis)} in {repr(self)} "
+                           "cannot be used as a sharded axis")

   replicated_mesh_axes = []
   for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):

@@ -1105,7 +1115,7 @@ def __repr__(self):
           f"sync={self.sync})")


-def preprocess(mesh, spec, parsed_pspec):
+def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
   # This split exists because you can pass `_parsed_pspec` that has been
   # modified from the original. For example: Adding extra dimension to
   # axis_resources for vmap handlers. In such cases you need to preserve the

@@ -1118,9 +1128,11 @@ def preprocess(mesh, spec, parsed_pspec):
         PartitionSpec() if spec is None else spec,
         "NamedSharding spec", allow_unconstrained_dims=True)

-  _check_mesh_resource_axis(mesh, parsed_pspec)
+  _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
   return parsed_pspec

+# fallback for c++.
+preprocess_with_manual = preprocess

 def prepare_axis_resources(axis_resources,
                            arg_name,
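
The `_manual_axes` argument threaded through `preprocess` and `_check_mesh_resource_axis` above turns a silent misuse into a clear error: inside a partial-auto `shard_map` body, a sharding constraint may not name a manual axis. A sketch of the failure mode, reusing `mesh`, `v`, `NamedSharding`, and `P` from the example near the top (this mirrors the `test_partial_auto_error_wsc_manual` case added below):

# 'i' is manual inside the body (only 'j' is auto), so constraining the
# intermediate on 'i' is rejected by the new _manual_axes check.
def g_bad(x):
  x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('i', 'j')))
  return x * x

f_bad = jax.jit(shard_map(g_bad, mesh,
                          in_specs=P('i', None), out_specs=P('i', None),
                          check_rep=False, auto=frozenset({'j'})))
try:
  f_bad(v)
except ValueError as e:
  print(e)  # mentions that 'i' "is also found in manual_axes"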

jax/experimental/shard_map.py

Lines changed: 36 additions & 10 deletions

@@ -93,9 +93,12 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
   if not isinstance(mesh, Mesh):
     raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its "
                     f"second argument, but got {mesh} of type {type(mesh)}.")
-  _check_specs(SpecErrorType.input, in_specs)
+  if not auto.issubset(mesh.axis_names):
+    raise ValueError(f"shard_map requires auto={auto} to be a subset of "
+                     f"mesh.axis_names={mesh.axis_names}")
+  _check_specs(SpecErrorType.input, in_specs, auto)
   if not callable(out_specs):
-    _check_specs(SpecErrorType.out, out_specs)
+    _check_specs(SpecErrorType.out, out_specs, auto)

   @util.wraps(f)
   @traceback_util.api_boundary

@@ -114,7 +117,7 @@ def wrapped(*args):
     def out_names_thunk():
       if callable(out_specs):
         out_specs_ = out_specs()
-        _check_specs(SpecErrorType.out, out_specs_)
+        _check_specs(SpecErrorType.out, out_specs_, auto)
       else:
         out_specs_ = out_specs
       dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)

@@ -162,17 +165,40 @@ def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:

 SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])

-def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
+def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None:
   if error_type == SpecErrorType.input and specs is None:
     raise TypeError(
         "shard_map in_specs argument must be a pytree of "
         "`jax.sharding.PartitionSpec` instances, but it was None.\n"
         "Instead of `in_specs=None`, did you mean `in_specs=P()`, "
         "where `P = jax.sharding.PartitionSpec`?")
-  if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
+  def check_spec(p):
+    if not isinstance(p, PartitionSpec):
+      return False
+    for names in p:
+      if not isinstance(names, tuple):
+        names = (names,)
+      for name in names:
+        if name in auto:
+          return False
+    return True
+  if all(check_spec(p) for p in tree_leaves(specs)): return
   prefix = 'in' if error_type == SpecErrorType.input else 'out'
   msgs = [f"  {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
           for key, x in generate_key_paths(specs) if not isinstance(x, P)]
+  if not msgs:
+    for key, p in generate_key_paths(specs):
+      for names in p:
+        if not isinstance(names, tuple):
+          names = (names,)
+        for name in names:
+          if name in auto:
+            msgs.append(f"  {prefix}_specs{keystr(key)} refers to {repr(name)}")
+    raise ValueError(
+        f"shard_map {prefix}_specs argument cannot refer to an axis "
+        f"marked auto ({auto}), but:\n\n"
+        + '\n\n'.join(msgs) + '\n\n'
+        f"Check the {prefix}_specs values passed to shard_map.")
   raise TypeError(
       f"shard_map {prefix}_specs argument must be a pytree of "
       f"`jax.sharding.PartitionSpec` instances, but:\n\n"

@@ -549,7 +575,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
   in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
                   in_avals_, in_nodes)
   new_axis_context = sharding_impls.SPMDAxisContext(
-      mesh, frozenset(mesh.axis_names)
+      mesh, frozenset(mesh.axis_names) - auto
   )
   sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
   with core.extend_axis_env_nd(tuple(mesh.shape.items())):

@@ -575,20 +601,20 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   unspecified = set(range(aval_in.ndim)) if auto else set()
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,  # type: ignore
                                   unspecified_dims=unspecified)
-  return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]
+  return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)]

 def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
                  aval_in, aval_out, xs):
   x, = xs
-  manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
     ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
     aval_out = core.physical_aval(aval_out)
-  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   unspecified = set(range(aval_out.ndim)) if auto else set()
+  manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
+  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                          unspecified)  # type: ignore

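With the `_check_specs` change above, `in_specs` and `out_specs` may only name manual axes; naming an axis listed in `auto` is rejected eagerly, before any tracing. A sketch of that failure mode, again with the illustrative `mesh` and `P` from earlier:

# Naming the auto axis 'j' in in_specs is an error: those dimensions
# belong to the partitioner, not to shard_map. The check fires as soon
# as shard_map is applied.
try:
  shard_map(lambda x: x * x, mesh,
            in_specs=P('i', 'j'),     # 'j' is marked auto: rejected
            out_specs=P('i', None),
            check_rep=False, auto=frozenset({'j'}))
except ValueError as e:
  print(e)  # "shard_map in_specs argument cannot refer to an axis marked auto ..."
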
tests/shard_map_test.py

Lines changed: 91 additions & 0 deletions

@@ -1625,6 +1625,97 @@ def f(inputs):
     jtu.check_grads(f, (list(jnp.arange(float(num_args))[:,None]),), order=1,
                     modes=['rev'], atol=1e-3, rtol=1e-3)

+  def test_partial_auto(self):
+    mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
+
+    def g(x):
+      x = jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P(None, 'j')))
+      return x * x
+
+    @jax.jit
+    def f(x):
+      x = shard_map(g, mesh,
+                    in_specs=P('i', None),
+                    out_specs=P('i', None),
+                    check_rep=False,
+                    auto=frozenset({'j'}))(x)
+      return jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+
+    v = jnp.arange(32.).reshape(4, 8)
+    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    self.assertAllClose(v*v, f(v), check_dtypes=False)
+
+  def test_partial_auto_error_wsc_manual(self):
+    mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
+
+    def g(x):
+      x = jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+      return x * x
+
+    @jax.jit
+    def f(x):
+      x = shard_map(g, mesh,
+                    in_specs=P('i', None),
+                    out_specs=P('i', None),
+                    check_rep=False,
+                    auto=frozenset({'j'}))(x)
+      return jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+
+    v = jnp.arange(32.).reshape(4, 8)
+    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    with self.assertRaisesRegex(ValueError, "manual"):
+      f(v)
+
+  def test_partial_auto_error_invalid_auto(self):
+    mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
+
+    def g(x):
+      x = jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+      return x * x
+
+    @jax.jit
+    def f(x):
+      x = shard_map(g, mesh,
+                    in_specs=P('i', None),
+                    out_specs=P('i', None),
+                    check_rep=False,
+                    auto=frozenset({'k'}))(x)
+      return jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+
+    v = jnp.arange(32.).reshape(4, 8)
+    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"):
+      f(v)
+
+  def test_partial_auto_error_wrong_in_specs(self):
+    mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
+
+    def g(x):
+      x = jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+      return x * x
+
+    @jax.jit
+    def f(x):
+      x = shard_map(g, mesh,
+                    in_specs=P('i', 'j'),
+                    out_specs=P('i', None),
+                    check_rep=False,
+                    auto=frozenset({'j'}))(x)
+      return jax.lax.with_sharding_constraint(
+          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+
+    v = jnp.arange(32.).reshape(4, 8)
+    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"):
+      f(v)
+

 class FunSpec(NamedTuple):
   name: str