
Commit c42a035

yashk2810 authored and jax authors committed
Let XLA choose in_shardings for inputs whose sharding is unspecified.

This is a strict improvement over the current state where JAX always chooses replicated sharding.

PiperOrigin-RevId: 610771289
1 parent 98e4b9e commit c42a035
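
As a quick illustration of the change (a minimal sketch, assuming a multi-device runtime and a recent jaxlib; the exact shardings XLA picks are compiler-dependent), leaving in_shardings unspecified on jit now lets the compiler choose the parameter shardings rather than forcing replication:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative setup: shard a weight matrix over all available devices
# (assumes the device count divides 8).
mesh = Mesh(jax.devices(), ('x',))
w = jax.device_put(jnp.ones((8, 128)), NamedSharding(mesh, P('x', None)))

@jax.jit  # note: no in_shardings given
def f(w, x):
  return w @ x

compiled = f.lower(w, jnp.ones((128, 16))).compile()
# With this change, the input shardings reported below may be chosen by XLA
# instead of always being replicated (exact values depend on the compiler).
print(compiled.input_shardings)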

File tree

5 files changed, +132 -49 lines changed


jax/_src/array.py

Lines changed: 1 addition & 1 deletion
@@ -766,7 +766,7 @@ def make_array_from_single_device_arrays(
     >>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
     >>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()

-    When using multiple processes, a common data pipeling is to have data parallelism across devices,
+    When using multiple processes, a common data pipeline is to have data parallelism across devices,
     with each device receiving at least one example. In this case, the following recipe will use
     `make_array_from_single_device_arrays` to create a global jax.Array.

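
For reference, a hedged sketch of the data-parallel recipe the corrected docstring line describes (the 4x2 mesh and the per-device batches are illustrative assumptions, not part of this diff):

import math
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))  # assumes 8 devices
sharding = NamedSharding(mesh, P('x', 'y'))

# Each process materializes only the shards for its own addressable devices.
global_data = np.arange(math.prod(global_shape)).reshape(global_shape)
arrays = [
    jax.device_put(global_data[index], d)
    for d, index in sharding.addressable_devices_indices_map(global_shape).items()
]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
assert arr.shape == global_shape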

jax/_src/interpreters/pxla.py

Lines changed: 95 additions & 19 deletions
@@ -2006,8 +2006,15 @@ def lower_sharding_computation(
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings))

-  gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
-  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
+  gs = GSPMDSharding.get_replicated(device_assignment)
+  if xla_extension_version < 240 or hasattr(backend, "compile_replicated"):
+    in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
+
+  # TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
+  # output sharding of XLA is partially constrained on the trailing dimensions.
+  in_shardings = tuple(
+      gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
+      else i for i, a in safe_zip(in_shardings, global_in_avals))

   da_object = _create_da_object(tuple(device_assignment))

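
Note the extended-dtype guard above: inputs such as typed PRNG key arrays keep the replicated fallback instead of having XLA infer their sharding, per the TODO. A tiny check with public API (not part of the diff):

import jax

key = jax.random.key(0)  # typed key array with an extended dtype
# True: key arrays fall under the dtypes.extended branch above, so their
# input sharding is still forced to replicated rather than chosen by XLA.
print(jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended))
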

@@ -2318,7 +2325,7 @@ def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
   return input_indices


-def get_gspmd_shardings_from_executable(
+def get_out_shardings_from_executable(
     xla_executable,
     device_assignment: Sequence[xc.Device],
     num_out_avals: int,
@@ -2374,6 +2381,32 @@ def get_gspmd_shardings_from_executable(
           for os, mk in safe_zip(out_op_shardings, omk)]


+def _get_in_shardings_from_xla(
+    xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
+    num_ordered_effects: int
+) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+  """Returns input shardings from XLA."""
+  from jax._src import pjit
+
+  # When the device assignment only has 1 device, SPMD partitioner will not run.
+  # Hence the op shardings will not be set on the `hlo_module`.
+  if len(device_assignment) == 1:
+    return [sharding_impls.SingleDeviceSharding(device_assignment[0])] * num_in_avals
+
+  in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
+  if not in_op_shardings:
+    return None
+
+  if num_ordered_effects > 0:
+    in_op_shardings = in_op_shardings[num_ordered_effects:]
+
+  assert len(in_op_shardings) == num_in_avals, (
+      len(in_op_shardings), num_in_avals)
+
+  return [sharding_impls.GSPMDSharding(device_assignment, os)
+          for os in in_op_shardings]
+
+
 # TODO(yashkatariya): Remove this function after `AUTO` can return shardings
 # without mesh.
 def _get_mesh_pspec_shardings_from_executable(
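
For orientation, the two kinds of shardings this helper can hand back, shown with public classes (a sketch, not the internal code path):

import jax
from jax.sharding import GSPMDSharding, SingleDeviceSharding

devices = jax.devices()
if len(devices) == 1:
  # One-device shortcut above: the SPMD partitioner does not run.
  s = SingleDeviceSharding(devices[0])
else:
  # Otherwise the helper wraps each parameter's op sharding in a GSPMDSharding;
  # a replicated one is shown here purely for illustration.
  s = GSPMDSharding.get_replicated(devices)
print(s.is_fully_replicated)  # True for both cases shown here
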
@@ -2526,8 +2559,8 @@ def get_logical_mesh_ids(mesh_shape):

 @weakref_lru_cache
 def _cached_compilation(computation, name, mesh, spmd_lowering,
-                        tuple_args, auto_spmd_lowering,
-                        _allow_propagation_to_outputs, host_callbacks, backend,
+                        tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
+                        allow_prop_to_outputs, host_callbacks, backend,
                         da, pmap_nreps, compiler_options_keys,
                         compiler_options_values):
   # TODO(phawkins): One would normally just write:
@@ -2580,7 +2613,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
          get_logical_mesh_ids(list(mesh.shape.values()))
          .reshape(-1))
   compile_options.parameter_is_tupled_arguments = tuple_args
-  opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
+  if xla_extension_version >= 240:
+    opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
+  opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)

   if hasattr(backend, "compile_replicated"):
     return None, compile_options
@@ -2593,22 +2628,59 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
   return xla_executable, compile_options


-def _get_shardings_from_executable(
+def _maybe_get_and_check_in_shardings(
+    xla_executable, in_shardings, device_assignment,
+    global_in_avals, num_ordered_effects):
+  """Returns in_shardings extracted from XLA or checks and returns original
+  shardings.
+
+  If in_shardings exist on `jit` or on `jax.Array`, then this function will
+  check that sharding against what XLA returns as in_shardings. If they don't
+  match, an error is raised.
+
+  If in_sharding is unspecified, then the sharding returned by XLA is returned.
+  """
+  in_shardings_xla = _get_in_shardings_from_xla(  # type: ignore
+      xla_executable, device_assignment, len(global_in_avals),
+      num_ordered_effects)  # type: ignore
+  if in_shardings_xla is None:
+    return in_shardings
+
+  new_in_shardings = []
+  for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
+                                    global_in_avals):
+    if is_unspecified(orig):
+      new_in_shardings.append(xla_s)
+    else:
+      xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+      orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+      # MANUAL HloSharding comes from other partitioning frameworks.
+      if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
+          not xla_hlo_s.is_manual() and
+          (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
+           xla_s.memory_kind != orig.memory_kind)):  # type: ignore
+        raise AssertionError(
+            f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
+            "(User sharding)")
+      new_in_shardings.append(orig)
+  return new_in_shardings
+
+
+def _get_out_shardings_from_executable(
     xla_executable, out_shardings, device_assignment, global_out_avals,
     num_ordered_effects, all_default_mem_kind
 ):
-  out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
+  out_shardings_xla = get_out_shardings_from_executable(  # type: ignore
       xla_executable, device_assignment, len(global_out_avals),
       num_ordered_effects, all_default_mem_kind)  # type: ignore
   if out_shardings_xla is None:
     return out_shardings, (False,) * len(global_out_avals)

-  orig_out_shardings = out_shardings
-  out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
-  for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
+  new_out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
+  for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
     if is_unspecified(orig):
-      out_shardings.append(xla_s)
+      new_out_shardings.append(xla_s)
       are_out_shardings_from_xla.append(True)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
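
The check above only fires when a user-provided sharding disagrees with what XLA reports for that parameter. A hedged sketch of an equivalence check using the public Sharding API (the internal code compares HLO shardings and memory kinds directly):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))  # assumes >= 2 devices
user_sharding = NamedSharding(mesh, P('x', None))
xla_sharding = NamedSharding(mesh, P('x'))  # stand-in for the XLA-reported sharding
# Equivalent partitionings of a rank-2 array, so no error would be raised.
print(user_sharding.is_equivalent_to(xla_sharding, 2))
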
@@ -2621,9 +2693,9 @@ def _get_shardings_from_executable(
         raise AssertionError(
             f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
             "(User sharding)")
-      out_shardings.append(orig)
+      new_out_shardings.append(orig)
       are_out_shardings_from_xla.append(False)
-  return out_shardings, are_out_shardings_from_xla
+  return new_out_shardings, are_out_shardings_from_xla


 def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
@@ -2722,6 +2794,8 @@ def from_hlo(name: str,
   else:
     da = _create_da_object(tuple(device_assignment))
   del device_assignment
+
+  allow_prop_to_inputs = tuple(is_unspecified(i) for i in in_shardings)
   allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)

   mesh = None
@@ -2733,8 +2807,8 @@ def from_hlo(name: str,

   xla_executable, compile_options = _cached_compilation(
       hlo, name, mesh, spmd_lowering,
-      tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
-      tuple(host_callbacks), backend, da, pmap_nreps,
+      tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
+      allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
       compiler_options_keys, compiler_options_values)

   if hasattr(backend, "compile_replicated"):
@@ -2761,9 +2835,11 @@ def from_hlo(name: str,
   else:
     if pmap_nreps == 1:
       assert mesh is None
-      # TODO(yashkatariya): Make da directly usable in the downstream code
-      # without tuple conversion.
-      out_shardings, are_out_shardings_from_xla = _get_shardings_from_executable(
+      if xla_extension_version >= 240:
+        in_shardings = _maybe_get_and_check_in_shardings(
+            xla_executable, in_shardings, tuple(da), global_in_avals,
+            len(ordered_effects))
+      out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable(
           xla_executable, out_shardings, tuple(da), global_out_avals,
           len(ordered_effects), all_default_mem_kind)
     else:
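
The new in_shardings path is gated on the installed jaxlib: only when xla_extension_version is at least 240 (and the backend does not use compile_replicated) does XLA get to choose parameter shardings; otherwise the old replicated fallback applies. A small, hedged sketch of how the propagation flags relate to unspecified shardings (UNSPECIFIED and is_unspecified are internal symbols used by this diff):

import jax
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import UNSPECIFIED, is_unspecified
from jax.sharding import SingleDeviceSharding

in_shardings = (UNSPECIFIED, SingleDeviceSharding(jax.devices()[0]))
# Mirrors allow_prop_to_inputs above: propagation is allowed exactly for the
# parameters whose sharding was left unspecified.
print(tuple(is_unspecified(i) for i in in_shardings))  # (True, False)
print(xla_extension_version >= 240)  # the new path requires a recent jaxlib
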

jax/experimental/jax2tf/tests/sharding_test.py

Lines changed: 15 additions & 4 deletions
@@ -247,14 +247,19 @@ def f_tf(x):
           # The argument
           (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]",
            count_in_P),
-          (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
-           count_in_replicated),
           # The result
           (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
            count_out_P),
+      ])
+    # TODO(b/326476605): Change the condition below if required.
+    if in_shardings not in [None, "missing"] and out_shardings is not None:
+      self.check_sharding(
+          jax2tf.convert(f_jax), [x],
+          checks=[
+              (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
+               count_in_replicated),
           (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated",
            count_out_replicated),
-          # No other shardings
           (r"custom_call_target.*Sharding",
            count_in_P + count_in_replicated + count_out_P + count_out_replicated),
       ])
@@ -437,10 +442,16 @@ def f_grad_tf(x_v, res_ct):
         checks=[
             # The input primal argument, and the output grad
             (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
+            # The primal result, and the input cotangent
+            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
+        ])
+    # TODO(b/326476605): Change the condition below if required.
+    if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]:
+      self.check_sharding(f_grad_tf, [x, x.T],
+                          checks=[
             (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
             # The primal result, and the input cotangent
             (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
-            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
         ])

   @jtu.parameterized_filterable(

tests/export_test.py

Lines changed: 11 additions & 10 deletions
@@ -916,7 +916,6 @@ def f_jax(b): # b: f32[2, 4]
                           res_r.addressable_shards[i].data)

   @jtu.parameterized_filterable(
-    one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
     kwargs=[
       dict(in_shardings=in_shardings, out_shardings=out_shardings,
            with_mesh=with_mesh)
@@ -971,15 +970,17 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10]
     else:
       primal_out_sharding = "{replicated}"

-    main = re.compile(
-        r"func.func public @main\(%arg0: tensor<10x20xf32>.*"
-        "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\""
-        r".*%arg1: tensor<20x10xf32>.*"
-        "mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\""
-        # result
-        r".*->.*\(tensor<10x20xf32>.*"
-        "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"")
-    self.assertRegex(vjp_module_str, main)
+    # TODO(b/326476605): Change the condition below if required.
+    if in_shardings == "P":
+      main = re.compile(
+          r"func.func public @main\(%arg0: tensor<10x20xf32>.*"
+          "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\""
+          r".*%arg1: tensor<20x10xf32>.*"
+          "mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\""
+          # result
+          r".*->.*\(tensor<10x20xf32>.*"
+          "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"")
+      self.assertRegex(vjp_module_str, main)

     # Custom calls for the primal input shape all match primal_in_sharding
     primal_in_calls = re.findall(
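
For context, the regex above targets mhlo.sharding attributes on the exported module's @main arguments; with this change they are only expected when an input sharding was actually given. A hedged sketch of inspecting the lowered text on the JAX side (the exact module text is compiler- and version-dependent):

import re
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(1, 2), ('x', 'y'))  # assumes >= 2 devices
f = jax.jit(lambda x: x.T, in_shardings=NamedSharding(mesh, P('x', 'y')))
module_str = f.lower(jnp.ones((10, 20), np.float32)).as_text()
# The constrained argument carries an mhlo.sharding attribute in the lowered module.
print(bool(re.search(r'@main\(%arg0: tensor<10x20xf32>.*mhlo\.sharding', module_str)))
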

tests/pjit_test.py

Lines changed: 10 additions & 15 deletions
@@ -559,8 +559,6 @@ def f(x):
     # Annotations from with_sharding_constraint
     self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
     self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
-    # Annotation from pjit
-    self.assertIn("sharding={replicated}", hlo.as_hlo_text())

   def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self):

@@ -1718,19 +1716,16 @@ def test_numpy_array_input_assume_fully_replicated(self):
     input_shape = (8, 2)
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     input_data = np.arange(
-        math.prod(input_shape), dtype=np.float32).reshape(input_shape)
-    with global_mesh:
-      f = pjit(lambda x: x,
-               out_shardings=NamedSharding(
-                   global_mesh, P('x', 'y')))
-      # Since no in_axis_resources is provided, pjit will assume that
-      # the numpy input is fully replicated over the mesh.
-      out = f(input_data)
-      self.assertIsInstance(out, array.ArrayImpl)
-      for s in out.addressable_shards:
-        self.assertEqual(s.data.shape, (2, 1))
-        self.assertArraysEqual(s.data, input_data[s.index])
-      self.assertArraysEqual(out._value, input_data)
+        math.prod(input_shape)).reshape(input_shape)
+
+    f = pjit(lambda x: x,
+             out_shardings=NamedSharding(global_mesh, P('x', 'y')))
+    out = f(input_data)
+    self.assertIsInstance(out, array.ArrayImpl)
+    self.assertArraysEqual(out, input_data)
+    for s in out.addressable_shards:
+      self.assertEqual(s.data.shape, (2, 1))
+      self.assertArraysEqual(s.data, input_data[s.index])

   def test_numpy_array_input(self):
     input_shape = (8, 2)
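
The rewritten test above reflects the same behavior from the user's side: a NumPy input with only out_shardings specified still produces a correctly sharded output. A hedged, minimal version with the public jax.jit API (the 4x2 mesh and the (2, 1) shard shapes are assumptions that require 8 devices):

import math
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))  # assumes 8 devices
input_data = np.arange(math.prod((8, 2))).reshape((8, 2))

f = jax.jit(lambda x: x, out_shardings=NamedSharding(mesh, P('x', 'y')))
out = f(input_data)
assert (np.asarray(out) == input_data).all()
for s in out.addressable_shards:  # each shard is a (2, 1) slice of the input
  assert s.data.shape == (2, 1)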
