Skip to content

Commit a1c8207

Browse files
yashk2810jax authors
authored andcommitted
Add kwargs support to in_shardings argument of jax.jit.
Currently, we only support this case: * If kwargs are specified, then all in_shardings should be specified as dict matching the kwargs. args and kwargs mixture is not allowed. Either everything are kwargs or args hence in_shardings is a dict or specified positionally. Example: ``` @partial(jax.jit, in_shardings=dict(y=s2, x=s1)) def f(x, y): return x * 2, y * 2 f(x=arr, y=arr2) ``` Fixes #17400 PiperOrigin-RevId: 623018032
1 parent 1b3aea8 commit a1c8207

File tree

3 files changed

+183
-27
lines changed

3 files changed

+183
-27
lines changed

jax/_src/api_util.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_
187187
inspect.Parameter.KEYWORD_ONLY,
188188
)
189189
def _validate_argnames(
190-
sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str
190+
sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str,
191+
err_on_invalid: bool = False
191192
) -> None:
192193
"""
193194
Validate that the argnames are sensible for a given function.
@@ -212,13 +213,15 @@ def _validate_argnames(
212213
# Check whether any kwargs are invalid due to position only
213214
invalid_argnames = invalid_kwargs & set(argnames)
214215
if invalid_argnames:
215-
# raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
216-
# f"in {argnames_name}. These are positional-only")
217-
# TODO: 2022-08-20 or later: replace with error
218-
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
219-
f"in {argnames_name}. These are positional-only. "
220-
"This warning will be replaced by an error after 2022-08-20 "
221-
"at the earliest.", SyntaxWarning)
216+
if err_on_invalid:
217+
raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
218+
f"in {argnames_name}. These are positional-only")
219+
else:
220+
# TODO: 2022-08-20 or later: replace with error
221+
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
222+
f"in {argnames_name}. These are positional-only. "
223+
"This warning will be replaced by an error after 2022-08-20 "
224+
"at the earliest.", SyntaxWarning)
222225

223226
# Takes any kwargs
224227
if var_kwargs:
@@ -227,13 +230,16 @@ def _validate_argnames(
227230
# Check that all argnames exist on function
228231
invalid_argnames = set(argnames) - valid_kwargs
229232
if invalid_argnames:
230-
# TODO: 2022-08-20 or later: replace with error
231-
# raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
232-
# f"in {argnames_name}. Function does not take these args.")
233-
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
234-
f"in {argnames_name}. Function does not take these args."
235-
"This warning will be replaced by an error after 2022-08-20 "
236-
"at the earliest.", SyntaxWarning)
233+
if err_on_invalid:
234+
raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
235+
f"in {argnames_name}. Function does not take these args. "
236+
f"Valid kwargs are {valid_kwargs}")
237+
else:
238+
# TODO: 2022-08-20 or later: replace with error
239+
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
240+
f"in {argnames_name}. Function does not take these args."
241+
"This warning will be replaced by an error after 2022-08-20 "
242+
"at the earliest.", SyntaxWarning)
237243

238244

239245
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):

jax/_src/pjit.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,35 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
510510
return _make_jit_wrapper(jit_info)
511511

512512

513+
def _check_in_shardings_kwargs_compatibility(
514+
kws: bool, has_args: bool, sig, in_shardings_treedef,
515+
user_specified_in_shardings: bool):
516+
if not user_specified_in_shardings:
517+
return
518+
node_data = in_shardings_treedef.node_data()
519+
if node_data is None:
520+
return
521+
type_, keys = node_data
522+
if not kws and type_ is not dict:
523+
return
524+
# TODO(yashkatariya): Try to allow these cases as and when need arises.
525+
if has_args and kws and type_ is dict:
526+
raise ValueError(
527+
'If in_shardings is a dict, the function should be called with kwargs'
528+
' only.')
529+
if kws and type_ is not dict:
530+
raise ValueError(
531+
'If you are using kwargs, in_shardings needs to be passed as a dict'
532+
f' corresponding to the kwargs. Got in_shardings type: {type_}')
533+
if not kws and type_ is dict:
534+
raise ValueError(
535+
'in_shardings can only be an instance of dict if you have kwargs.'
536+
' Please pass in_shardings positionally if you are using args.')
537+
assert type_ is dict
538+
api_util._validate_argnames(sig, tuple(keys), 'in_shardings',
539+
err_on_invalid=True)
540+
541+
513542
def _infer_params(jit_info, args, kwargs):
514543
(fun, fun_sourceinfo, fun_signature, user_specified_in_shardings,
515544
in_shardings_treedef, in_shardings_leaves, out_shardings_treedef,
@@ -519,9 +548,9 @@ def _infer_params(jit_info, args, kwargs):
519548
abstracted_axes, _, use_resource_env) = jit_info
520549

521550
have_kwargs = bool(kwargs)
522-
if have_kwargs and user_specified_in_shardings:
523-
raise ValueError(
524-
"pjit does not support kwargs when in_shardings is specified.")
551+
_check_in_shardings_kwargs_compatibility(
552+
have_kwargs, bool(args), fun_signature, in_shardings_treedef,
553+
user_specified_in_shardings)
525554

526555
if use_resource_env:
527556
# We need to fetch the mesh from inside the wrapped function, because
@@ -538,7 +567,6 @@ def _infer_params(jit_info, args, kwargs):
538567
pjit_mesh = None
539568
jit_name = 'jit'
540569

541-
542570
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
543571

544572
dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs,
@@ -998,21 +1026,36 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
9981026
in_layouts_treedef, in_layouts_leaves,
9991027
in_avals, in_tree, debug_info,
10001028
device_or_backend_set, kws):
1029+
in_tree_args, in_tree_kwargs = treedef_children(in_tree)
10011030
if not kws:
1002-
in_tree, _ = treedef_children(in_tree)
1031+
in_tree = in_tree_args
10031032

10041033
orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
10051034
# Only do this if original in_shardings are unspecified. If it is AUTO, go
10061035
# via flatten_axis_resources.
10071036
if is_unspecified(orig_in_shardings):
10081037
in_shardings_flat = (orig_in_shardings,) * len(in_avals)
1038+
elif isinstance(orig_in_shardings, dict):
1039+
if in_shardings_treedef != in_tree_kwargs:
1040+
# TODO(yashkatariya): Improve the error message drastically.
1041+
raise ValueError(
1042+
'Pytree of in_shardings and kwargs should be equal. Got in_shardings'
1043+
f' pytree: {in_shardings_treedef}, kwargs pytree: {in_tree_kwargs}')
1044+
in_shardings_flat = in_shardings_leaves
10091045
else:
10101046
in_shardings_flat = flatten_axis_resources(
10111047
"pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)
10121048

10131049
in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves)
10141050
if in_layouts is None:
10151051
in_layouts_flat = (in_layouts,) * len(in_avals)
1052+
elif isinstance(in_layouts, dict):
1053+
if in_layouts_treedef != in_tree_kwargs:
1054+
# TODO(yashkatariya): Improve the error message drastically.
1055+
raise ValueError(
1056+
'Pytree of in_layouts and kwargs should be equal. Got in_layouts'
1057+
f' pytree: {in_layouts_treedef}, kwargs pytree: {in_tree_kwargs}')
1058+
in_layouts_flat = in_layouts_leaves
10161059
else:
10171060
in_layouts_flat = flatten_axis_resources(
10181061
"pjit in_layouts", in_tree, in_layouts, tupled_args=True)

tests/pjit_test.py

Lines changed: 114 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,13 +2822,6 @@ def f(x, y, z):
28222822
self.assertEqual(cache_info3.hits, cache_info2.hits)
28232823
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
28242824

2825-
def test_pjit_kwargs_axis_resources_error(self):
2826-
with self.assertRaisesRegex(
2827-
ValueError,
2828-
"pjit does not support kwargs when in_shardings is specified."):
2829-
pjit(lambda x: x,
2830-
in_shardings=SingleDeviceSharding(jax.devices()[0]))(x=jnp.arange(8.))
2831-
28322825
def test_pjit_keep_unused_true(self):
28332826
@partial(pjit, keep_unused=True)
28342827
def f(x, y, z, a, b, c): # pylint: disable=unused-argument
@@ -4015,6 +4008,120 @@ def f(*args):
40154008
inps = [arr, *[inp] * 2001]
40164009
f(inps) # doesn't crash
40174010

4011+
def test_in_shardings_kwargs(self):
4012+
mesh = jtu.create_global_mesh((2,), 'x')
4013+
s = NamedSharding(mesh, P('x'))
4014+
s1 = NamedSharding(mesh, P())
4015+
np_inp = np.arange(8)
4016+
arr = jax.device_put(np_inp, s1)
4017+
4018+
kw_in_s = dict(y=s1, x=s)
4019+
@partial(jax.jit, in_shardings=kw_in_s, out_shardings=(s, s1))
4020+
def f(x, y):
4021+
return x * 2, y * 2
4022+
4023+
out1, out2 = f(x=np_inp, y=arr)
4024+
self.assertArraysEqual(out1, np_inp * 2)
4025+
self.assertArraysEqual(out2, np_inp * 2)
4026+
self.assertEqual(out1.sharding, s)
4027+
self.assertEqual(out2.sharding, s1)
4028+
4029+
c_in_s = f.lower(x=np_inp, y=np_inp).compile().input_shardings
4030+
_, kw_shardings = c_in_s
4031+
self.assertDictEqual(kw_shardings, kw_in_s)
4032+
4033+
def test_kwargs_leaf_in_shardings(self):
4034+
s = SingleDeviceSharding(jax.devices()[0])
4035+
np_inp = np.arange(16).reshape(8, 2)
4036+
4037+
@partial(jax.jit, in_shardings=s)
4038+
def f(x, y):
4039+
return x @ y.T
4040+
4041+
out = f(x=np_inp, y=np_inp)
4042+
self.assertArraysEqual(out, np_inp @ np_inp.T)
4043+
4044+
def test_arg_in_shardings_kwarg_error(self):
4045+
s = SingleDeviceSharding(jax.devices()[0])
4046+
np_inp = np.arange(16).reshape(8, 2)
4047+
4048+
@partial(jax.jit, in_shardings=dict(x=s, y=s))
4049+
def f(x, y):
4050+
return x * 2, y * 2
4051+
4052+
with self.assertRaisesRegex(
4053+
ValueError,
4054+
'in_shardings can only be an instance of dict if you have kwargs.'
4055+
' Please pass in_shardings positionally if you are using args.'):
4056+
f(np_inp, np_inp)
4057+
4058+
def test_kwarg_in_shardings_positional_error(self):
4059+
s = SingleDeviceSharding(jax.devices()[0])
4060+
np_inp = np.arange(16).reshape(8, 2)
4061+
4062+
@partial(jax.jit, in_shardings=(s, s))
4063+
def f(x, y):
4064+
return x * 2, y * 2
4065+
4066+
with self.assertRaisesRegex(
4067+
ValueError,
4068+
'If you are using kwargs, in_shardings needs to be passed as a dict'
4069+
' corresponding to the kwargs'):
4070+
f(x=np_inp, y=np_inp)
4071+
4072+
def test_invalid_in_shardings_kwarg(self):
4073+
s = SingleDeviceSharding(jax.devices()[0])
4074+
np_inp = np.arange(16).reshape(8, 2)
4075+
4076+
@partial(jax.jit, in_shardings=dict(z=s))
4077+
def f(x):
4078+
return x * 2
4079+
4080+
with self.assertRaisesRegex(
4081+
ValueError,
4082+
"Jitted function has invalid argnames {'z'} in in_shardings. Function"
4083+
" does not take these args. Valid kwargs are {'x'}"):
4084+
f(x=np_inp)
4085+
4086+
def test_mix_args_kwargs_in_shardings(self):
4087+
s = SingleDeviceSharding(jax.devices()[0])
4088+
np_inp = np.arange(16).reshape(8, 2)
4089+
4090+
@partial(jax.jit, in_shardings=dict(x=s, y=s, z=s))
4091+
def f(x, y, z):
4092+
return x * 2, y * 2, z * 2
4093+
4094+
with self.assertRaisesRegex(
4095+
ValueError,
4096+
'If in_shardings is a dict, the function should be called with kwargs'
4097+
' only.'):
4098+
f(np_inp, y=np_inp, z=np_inp)
4099+
4100+
def test_kwargs_in_shardings_partial(self):
4101+
s = SingleDeviceSharding(jax.devices()[0])
4102+
np_inp = np.arange(16).reshape(8, 2)
4103+
4104+
@partial(jax.jit, in_shardings=dict(x=s))
4105+
def f(x, y, z):
4106+
return x * 2, y * 2, z * 2
4107+
4108+
with self.assertRaisesRegex(
4109+
ValueError, "Pytree of in_shardings and kwargs should be equal"):
4110+
f(x=np_inp, y=np_inp, z=np_inp)
4111+
4112+
def test_args_kwargs_in_shardings_mixture(self):
4113+
s = SingleDeviceSharding(jax.devices()[0])
4114+
np_inp = np.arange(16).reshape(8, 2)
4115+
4116+
@partial(jax.jit, in_shardings=(s, dict(y=s, z=s)))
4117+
def f(x, y, z):
4118+
return x * 2, y * 2, z * 2
4119+
4120+
with self.assertRaisesRegex(
4121+
ValueError,
4122+
'If you are using kwargs, in_shardings needs to be passed as a dict'):
4123+
f(np_inp, y=np_inp, z=np_inp)
4124+
40184125

40194126
class TempSharding(Sharding):
40204127

0 commit comments

Comments
 (0)