Skip to content

Commit 5c9c57f

Browse files
yashk2810jax authors
authored andcommitted
Allow partially specified shardings in in_shardings and out_shardings parameters of jax.jit.
PiperOrigin-RevId: 611848778
1 parent 3a89557 commit 5c9c57f

File tree

2 files changed

+20
-38
lines changed

2 files changed

+20
-38
lines changed

jax/_src/sharding_impls.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,33 +1113,13 @@ def __repr__(self):
11131113
f"sync={self.sync})")
11141114

11151115

1116-
def check_all_or_none_unspecified(axis_resources, name):
1117-
if not axis_resources:
1118-
return False
1119-
unspecified_count = 0
1120-
unspecified = is_unspecified(axis_resources[0])
1121-
for resource in axis_resources:
1122-
current_is_unspecified = is_unspecified(resource)
1123-
if current_is_unspecified:
1124-
unspecified_count += 1
1125-
assert unspecified_count == 1
1126-
if current_is_unspecified != unspecified:
1127-
raise ValueError(f'`pjit.UNSPECIFIED` exists in {name}. '
1128-
f'Make sure that every entry in {name} is '
1129-
'`pjit.UNSPECIFIED`.')
1130-
return unspecified
1131-
1132-
11331116
def prepare_axis_resources(axis_resources,
11341117
arg_name,
11351118
allow_unconstrained_dims=False):
11361119
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
11371120
entries, treedef = tree_util.tree_flatten(
11381121
axis_resources, is_leaf=lambda x: x is None)
11391122
what = f"{arg_name} leaf specifications"
1140-
# All entries should be specified or if unspecified then there should only
1141-
# be 1 entry for that since UNSPECIFIED is a private API.
1142-
check_all_or_none_unspecified(entries, arg_name)
11431123

11441124
new_entries = []
11451125
for entry in entries:

tests/pjit_test.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3881,6 +3881,26 @@ def f():
38813881
lowered_text = make_keys.lower(seeds).as_text()
38823882
self.assertIn('unspecified_dims=[0,1]', lowered_text)
38833883

3884+
def test_jit_partially_specified_shardings(self):
3885+
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
3886+
np_inp = np.arange(16).reshape(8, 2)
3887+
s = NamedSharding(mesh, P('x', 'y'))
3888+
s2 = NamedSharding(mesh, P('x'))
3889+
arr = jax.device_put(np_inp, s)
3890+
arr2 = jax.device_put(np_inp, s2)
3891+
3892+
@partial(jax.jit, in_shardings=(s, None, s2, UNSPECIFIED, UNSPECIFIED),
3893+
out_shardings=(s2, None, None, s, None))
3894+
def f(x, y, z, a, b):
3895+
return x * 2, y @ y.T, z ** 2, a * 3, b.T
3896+
3897+
out1, out2, out3, out4, out5 = f(arr, np_inp, arr2, np_inp, arr)
3898+
self.assertArraysEqual(out1, np_inp * 2)
3899+
self.assertArraysEqual(out2, np_inp @ np_inp.T)
3900+
self.assertArraysEqual(out3, np_inp ** 2)
3901+
self.assertArraysEqual(out4, np_inp * 3)
3902+
self.assertArraysEqual(out5, np_inp.T)
3903+
38843904

38853905
class TempSharding(Sharding):
38863906

@@ -4314,24 +4334,6 @@ def test_array_mapping_to_axis_resources(self, inp, expected_out):
43144334
sharding_impls.array_mapping_to_axis_resources(inp), expected_out
43154335
)
43164336

4317-
@parameterized.named_parameters(
4318-
("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
4319-
("only_unspecified", UNSPECIFIED),
4320-
("all_specified", (P('x'), P('y'))),
4321-
("only_specified", P('x')),
4322-
("mix_1", (P('x'), UNSPECIFIED), ValueError),
4323-
("mix_2", (P('x'), UNSPECIFIED, P('y')), ValueError),
4324-
("mix_3", (UNSPECIFIED, P('x'), P('y')), ValueError),
4325-
("mix_4", (UNSPECIFIED, P('x'), UNSPECIFIED), ValueError),
4326-
)
4327-
def test_all_or_non_unspecified(self, axis_resources, error=None):
4328-
entries, _ = jax.tree.flatten(axis_resources, is_leaf=lambda x: x is None)
4329-
if error is not None:
4330-
with self.assertRaises(error):
4331-
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')
4332-
else:
4333-
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')
4334-
43354337
def test_op_sharding_equality_and_hash_equality(self):
43364338
op1 = xc.OpSharding()
43374339
op1.type = xc.OpSharding.Type.OTHER

0 commit comments

Comments
 (0)