Commit cd1e55a

yashk2810 authored and jax authors committed

Remove physical_hlo_sharding from TyRules.

The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to accept only logical arg_shardings and result_shardings, which are XLACompatibleShardings.

PiperOrigin-RevId: 616267810

1 parent: 9cf2fbe

File tree: 6 files changed, +35 -64 lines
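For context on the second part of the change: callers now hand lower_jaxpr_to_fun the logical sharding objects themselves (e.g. a NamedSharding), and mlir.py converts them to XLA OpSharding protos internally. A minimal sketch of the kind of object that now flows through; it assumes a single-device mesh so it runs on any backend, the variable names are illustrative, and jax.sharding.XLACompatibleSharding is as exported in JAX releases from around this commit:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P, XLACompatibleSharding

# A logical, device-aware sharding; no HloSharding/OpSharding conversion here.
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
logical = NamedSharding(mesh, P('x'))
print(isinstance(logical, XLACompatibleSharding))  # True: the type lower_jaxpr_to_fun now accepts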

jax/_src/interpreters/mlir.py

Lines changed: 11 additions & 24 deletions
@@ -819,16 +819,19 @@ class LoweringResult(NamedTuple):
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
 
 
-def _to_logical_op_sharding(
+def _to_physical_op_sharding(
     aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
-) -> xc.HloSharding | None:
+) -> xc.OpSharding | None:
   if sharding is None:
     return None
   assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
   if isinstance(aval, AbstractRef):
-    return _to_logical_op_sharding(aval.inner_aval, sharding)
+    return _to_physical_op_sharding(aval.inner_aval, sharding)
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
-  return sharding._to_xla_hlo_sharding(aval.ndim)
+  if dtypes.issubdtype(aval.dtype, dtypes.extended):
+    sharding = aval.dtype._rules.physical_sharding(aval, sharding)
+    aval = core.physical_aval(aval)
+  return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore
 
 
 def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:

@@ -941,13 +944,6 @@ def lower_jaxpr_to_module(
   else:
     dim_vars = ()
 
-  arg_op_shardings = (
-      map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
-      if arg_shardings is not None else arg_shardings)
-  result_op_shardings = (
-      map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
-      if result_shardings is not None else result_shardings)
-
   arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
                  else in_layouts)
   result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None

@@ -978,8 +974,8 @@ def lower_jaxpr_to_module(
       replace_tokens_with_dummy=replace_tokens_with_dummy,
       num_output_tokens=0,
       replicated_args=replicated_args,
-      arg_shardings=arg_op_shardings,
-      result_shardings=result_op_shardings,
+      arg_shardings=arg_shardings,
+      result_shardings=result_shardings,
       input_output_aliases=input_output_aliases,
       xla_donated_args=xla_donated_args,
       arg_names=arg_names,

@@ -1123,8 +1119,8 @@ def lower_jaxpr_to_fun(
     public: bool = False,
     replace_tokens_with_dummy: bool = False,
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[xc.HloSharding | None] | None = None,
-    result_shardings: Sequence[xc.HloSharding | None] | None = None,
+    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,

@@ -1483,15 +1479,6 @@ def wrap_with_memory_kind(
   return op.result
 
 
-def _to_physical_op_sharding(
-    aval: core.AbstractValue | None, sharding: xc.HloSharding | None
-) -> xc.OpSharding | None:
-  if (isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, dtypes.extended)
-      and sharding is not None):
-    return aval.dtype._rules.physical_hlo_sharding(aval, sharding).to_proto()
-  return None if sharding is None else sharding.to_proto()  # type: ignore
-
-
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""

jax/_src/lax/lax.py

Lines changed: 0 additions & 4 deletions
@@ -5110,10 +5110,6 @@ def handler(bufs):
       return core.DArray(aval, phys_handler(bufs))
     return handler
 
-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-    return hlo_sharding
-
   @staticmethod
   def logical_sharding(aval, phys_sharding):
     return phys_sharding

jax/_src/pjit.py

Lines changed: 2 additions & 4 deletions
@@ -1617,10 +1617,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
 
   func = mod_ctx.cached_primitive_lowerings.get(key, None)
   if func is None:
-    arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim)
-                     for aval, i in zip(ctx.avals_in, in_shardings)]
-    result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim)
-                        for aval, o in zip(ctx.avals_out, out_shardings)]
+    arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
+    result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
     # TODO(b/228598865): inlined calls cannot have shardings set directly on the
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.

jax/_src/prng.py

Lines changed: 12 additions & 13 deletions
@@ -342,8 +342,7 @@ def make_key_array_phys_sharding(aval, sharding):
   else:
     hlos = sharding._to_xla_hlo_sharding(aval.ndim)
     return GSPMDSharding(
-        sharding._device_assignment,
-        KeyTyRules.physical_hlo_sharding(aval, hlos))
+        sharding._device_assignment, physical_hlo_sharding(aval, hlos))
 
 
 def get_logical_gspmd_sharding(aval, phys_sharding):

@@ -361,6 +360,17 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
       xc.HloSharding.from_proto(logical_op_sharding))
 
 
+def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
+  key_shape = aval.dtype._impl.key_shape
+  new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
+  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+      hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  tad = partitions + [1] * len(key_shape) + suffix
+  new_op_sharding.tile_assignment_dimensions = tad
+  return xc.HloSharding.from_proto(new_op_sharding)
+
+
 class KeyTyRules:
 
   @staticmethod

@@ -382,17 +392,6 @@ def physical_element_aval(dtype) -> core.ShapedArray:
   def physical_const(val) -> Array:
     return val._base_array
 
-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-    key_shape = aval.dtype._impl.key_shape
-    new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
-    partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
-        hlo_sharding)
-    suffix = [] if num_replicas == 1 else [num_replicas]
-    tad = partitions + [1] * len(key_shape) + suffix
-    new_op_sharding.tile_assignment_dimensions = tad
-    return xc.HloSharding.from_proto(new_op_sharding)
-
   @staticmethod
   def physical_sharding(
       aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
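The now module-level physical_hlo_sharding extends a logical tiling with one trivial dimension per trailing key dimension, plus a replica suffix when the logical sharding is partially replicated. A worked example of just that arithmetic, with made-up inputs, kept as plain Python so it is independent of any device setup:

# Assume a logical keys array sharded 2-way on dim 0 and 4-way on dim 1,
# a key_shape of (2,) (the default threefry impl), and no extra replication.
partitions = [2, 4]
key_shape = (2,)
num_replicas = 1
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
print(tad)  # [2, 4, 1]: the trailing physical key dim stays unsharded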

jax/experimental/shard_map.py

Lines changed: 10 additions & 10 deletions
@@ -567,13 +567,13 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
                aval_in, aval_out, x):
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   axes = {name: i for i, ns in names.items() for name in ns}
-  shard_proto = NamedSharding(
-      mesh, sharding_impls.array_mapping_to_axis_resources(axes)  # type: ignore
-  )._to_xla_hlo_sharding(aval_in.ndim)
+  ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
-    shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto)
+    ns = aval_in.dtype._rules.physical_sharding(aval_in, ns)
+    aval_in = core.physical_aval(aval_in)
+  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
   unspecified = set(range(aval_in.ndim)) if auto else set()
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(),  # type: ignore
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,  # type: ignore
                                   unspecified_dims=unspecified)
   return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]
 

@@ -583,13 +583,13 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
   axes = {name: i for i, ns in names.items() for name in ns}
-  shard_proto = NamedSharding(
-      mesh, sharding_impls.array_mapping_to_axis_resources(axes)  # type: ignore
-  )._to_xla_hlo_sharding(aval_out.ndim)
+  ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
-    shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto)
+    ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
+    aval_out = core.physical_aval(aval_out)
+  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   unspecified = set(range(aval_out.ndim)) if auto else set()
-  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto.to_proto(),
+  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                          unspecified)  # type: ignore
 
 def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
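Both lowering paths now derive the sharding proto from the physical sharding and physical aval, so the proto's rank matches the buffer XLA actually partitions. A hedged end-to-end illustration of that logical/physical rank difference for a sharded key array (not a shard_map call itself); it assumes at least 4 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU, and new-style typed PRNG keys:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
keys = jax.device_put(jax.random.split(jax.random.key(0), 8),
                      NamedSharding(mesh, P('x')))
print(keys.shape)                       # (8,)   logical view the axis names refer to
print(jax.random.key_data(keys).shape)  # (8, 2) physical view the sharding proto must cover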

tests/lax_test.py

Lines changed: 0 additions & 9 deletions
@@ -45,7 +45,6 @@
 from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.util import NumpyComplexWarning
 

@@ -2989,14 +2988,6 @@ class FooTyRules:
   def physical_element_aval(dtype) -> core.ShapedArray:
     return core.ShapedArray((2,), jnp.dtype('uint32'))
 
-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding):
-    op_sharding_proto = hlo_sharding.to_proto()
-    new_op_sharding = op_sharding_proto.clone()
-    tad = list(new_op_sharding.tile_assignment_dimensions)
-    new_op_sharding.tile_assignment_dimensions = [*tad, 1]
-    return xc.HloSharding.from_proto(new_op_sharding)
-
   @staticmethod
   def logical_sharding(aval, phys_sharding):
     return phys_sharding
