Skip to content

Commit 52f7de0

Browse files
yashk2810jax authors
authored andcommitted
Remove the unused return from prepare_axis_resources
PiperOrigin-RevId: 621738698
1 parent bc0eff5 commit 52f7de0

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3286,7 +3286,7 @@ def check_array_xla_sharding_layout_match(
32863286

32873287

32883288
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
3289-
parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
3289+
parsed_pspec = sharding_impls.prepare_axis_resources(
32903290
pspec, "pspec to array_mapping")
32913291
return _get_array_mapping(parsed_pspec)
32923292

jax/_src/pjit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
378378
# rather than raising an error. https://github.com/google/jax/issues/2367
379379
in_shardings = tuple(in_shardings)
380380

381-
in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
382-
out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
381+
in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
382+
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
383383

384384
user_specified_in_shardings = (in_shardings is not None and
385385
not is_unspecified(in_shardings))
@@ -2163,7 +2163,7 @@ def with_sharding_constraint(x, shardings):
21632163
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
21642164
"""
21652165
x_flat, tree = tree_flatten(x)
2166-
user_shardings, _, _ = prepare_axis_resources(
2166+
user_shardings = prepare_axis_resources(
21672167
shardings, "shardings", allow_unconstrained_dims=True)
21682168
del shardings
21692169

jax/_src/sharding_impls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,7 +1115,7 @@ def preprocess(mesh, spec, parsed_pspec):
11151115
# TODO(yaskatariya): Remove this and replace this with a normalized
11161116
# representation of Parsed Pspec
11171117
if parsed_pspec is None:
1118-
parsed_pspec, _, _ = prepare_axis_resources(
1118+
parsed_pspec = prepare_axis_resources(
11191119
PartitionSpec() if spec is None else spec,
11201120
"NamedSharding spec", allow_unconstrained_dims=True)
11211121

@@ -1148,7 +1148,7 @@ def prepare_axis_resources(axis_resources,
11481148
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
11491149

11501150
_check_unique_resources(new_entries, arg_name)
1151-
return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef
1151+
return tree_util.tree_unflatten(treedef, new_entries)
11521152

11531153

11541154
def _check_unique_resources(axis_resources, arg_name):

0 commit comments

Comments
 (0)