Skip to content

Commit d1b1d0b

Browse files
yashk2810jax authors
authored andcommitted
Reverts a1c8207
PiperOrigin-RevId: 623045488
1 parent a1c8207 commit d1b1d0b

File tree

3 files changed

+27
-183
lines changed

3 files changed

+27
-183
lines changed

jax/_src/api_util.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_
187187
inspect.Parameter.KEYWORD_ONLY,
188188
)
189189
def _validate_argnames(
190-
sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str,
191-
err_on_invalid: bool = False
190+
sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str
192191
) -> None:
193192
"""
194193
Validate that the argnames are sensible for a given function.
@@ -213,15 +212,13 @@ def _validate_argnames(
213212
# Check whether any kwargs are invalid due to position only
214213
invalid_argnames = invalid_kwargs & set(argnames)
215214
if invalid_argnames:
216-
if err_on_invalid:
217-
raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
218-
f"in {argnames_name}. These are positional-only")
219-
else:
220-
# TODO: 2022-08-20 or later: replace with error
221-
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
222-
f"in {argnames_name}. These are positional-only. "
223-
"This warning will be replaced by an error after 2022-08-20 "
224-
"at the earliest.", SyntaxWarning)
215+
# raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
216+
# f"in {argnames_name}. These are positional-only")
217+
# TODO: 2022-08-20 or later: replace with error
218+
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
219+
f"in {argnames_name}. These are positional-only. "
220+
"This warning will be replaced by an error after 2022-08-20 "
221+
"at the earliest.", SyntaxWarning)
225222

226223
# Takes any kwargs
227224
if var_kwargs:
@@ -230,16 +227,13 @@ def _validate_argnames(
230227
# Check that all argnames exist on function
231228
invalid_argnames = set(argnames) - valid_kwargs
232229
if invalid_argnames:
233-
if err_on_invalid:
234-
raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
235-
f"in {argnames_name}. Function does not take these args. "
236-
f"Valid kwargs are {valid_kwargs}")
237-
else:
238-
# TODO: 2022-08-20 or later: replace with error
239-
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
240-
f"in {argnames_name}. Function does not take these args."
241-
"This warning will be replaced by an error after 2022-08-20 "
242-
"at the earliest.", SyntaxWarning)
230+
# TODO: 2022-08-20 or later: replace with error
231+
# raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
232+
# f"in {argnames_name}. Function does not take these args.")
233+
warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
234+
f"in {argnames_name}. Function does not take these args."
235+
"This warning will be replaced by an error after 2022-08-20 "
236+
"at the earliest.", SyntaxWarning)
243237

244238

245239
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):

jax/_src/pjit.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -510,35 +510,6 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
510510
return _make_jit_wrapper(jit_info)
511511

512512

513-
def _check_in_shardings_kwargs_compatibility(
514-
kws: bool, has_args: bool, sig, in_shardings_treedef,
515-
user_specified_in_shardings: bool):
516-
if not user_specified_in_shardings:
517-
return
518-
node_data = in_shardings_treedef.node_data()
519-
if node_data is None:
520-
return
521-
type_, keys = node_data
522-
if not kws and type_ is not dict:
523-
return
524-
# TODO(yashkatariya): Try to allow these cases as and when need arises.
525-
if has_args and kws and type_ is dict:
526-
raise ValueError(
527-
'If in_shardings is a dict, the function should be called with kwargs'
528-
' only.')
529-
if kws and type_ is not dict:
530-
raise ValueError(
531-
'If you are using kwargs, in_shardings needs to be passed as a dict'
532-
f' corresponding to the kwargs. Got in_shardings type: {type_}')
533-
if not kws and type_ is dict:
534-
raise ValueError(
535-
'in_shardings can only be an instance of dict if you have kwargs.'
536-
' Please pass in_shardings positionally if you are using args.')
537-
assert type_ is dict
538-
api_util._validate_argnames(sig, tuple(keys), 'in_shardings',
539-
err_on_invalid=True)
540-
541-
542513
def _infer_params(jit_info, args, kwargs):
543514
(fun, fun_sourceinfo, fun_signature, user_specified_in_shardings,
544515
in_shardings_treedef, in_shardings_leaves, out_shardings_treedef,
@@ -548,9 +519,9 @@ def _infer_params(jit_info, args, kwargs):
548519
abstracted_axes, _, use_resource_env) = jit_info
549520

550521
have_kwargs = bool(kwargs)
551-
_check_in_shardings_kwargs_compatibility(
552-
have_kwargs, bool(args), fun_signature, in_shardings_treedef,
553-
user_specified_in_shardings)
522+
if have_kwargs and user_specified_in_shardings:
523+
raise ValueError(
524+
"pjit does not support kwargs when in_shardings is specified.")
554525

555526
if use_resource_env:
556527
# We need to fetch the mesh from inside the wrapped function, because
@@ -567,6 +538,7 @@ def _infer_params(jit_info, args, kwargs):
567538
pjit_mesh = None
568539
jit_name = 'jit'
569540

541+
570542
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
571543

572544
dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs,
@@ -1026,36 +998,21 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
1026998
in_layouts_treedef, in_layouts_leaves,
1027999
in_avals, in_tree, debug_info,
10281000
device_or_backend_set, kws):
1029-
in_tree_args, in_tree_kwargs = treedef_children(in_tree)
10301001
if not kws:
1031-
in_tree = in_tree_args
1002+
in_tree, _ = treedef_children(in_tree)
10321003

10331004
orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
10341005
# Only do this if original in_shardings are unspecified. If it is AUTO, go
10351006
# via flatten_axis_resources.
10361007
if is_unspecified(orig_in_shardings):
10371008
in_shardings_flat = (orig_in_shardings,) * len(in_avals)
1038-
elif isinstance(orig_in_shardings, dict):
1039-
if in_shardings_treedef != in_tree_kwargs:
1040-
# TODO(yashkatariya): Improve the error message drastically.
1041-
raise ValueError(
1042-
'Pytree of in_shardings and kwargs should be equal. Got in_shardings'
1043-
f' pytree: {in_shardings_treedef}, kwargs pytree: {in_tree_kwargs}')
1044-
in_shardings_flat = in_shardings_leaves
10451009
else:
10461010
in_shardings_flat = flatten_axis_resources(
10471011
"pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)
10481012

10491013
in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves)
10501014
if in_layouts is None:
10511015
in_layouts_flat = (in_layouts,) * len(in_avals)
1052-
elif isinstance(in_layouts, dict):
1053-
if in_layouts_treedef != in_tree_kwargs:
1054-
# TODO(yashkatariya): Improve the error message drastically.
1055-
raise ValueError(
1056-
'Pytree of in_layouts and kwargs should be equal. Got in_layouts'
1057-
f' pytree: {in_layouts_treedef}, kwargs pytree: {in_tree_kwargs}')
1058-
in_layouts_flat = in_layouts_leaves
10591016
else:
10601017
in_layouts_flat = flatten_axis_resources(
10611018
"pjit in_layouts", in_tree, in_layouts, tupled_args=True)

tests/pjit_test.py

Lines changed: 7 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,6 +2822,13 @@ def f(x, y, z):
28222822
self.assertEqual(cache_info3.hits, cache_info2.hits)
28232823
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
28242824

2825+
def test_pjit_kwargs_axis_resources_error(self):
2826+
with self.assertRaisesRegex(
2827+
ValueError,
2828+
"pjit does not support kwargs when in_shardings is specified."):
2829+
pjit(lambda x: x,
2830+
in_shardings=SingleDeviceSharding(jax.devices()[0]))(x=jnp.arange(8.))
2831+
28252832
def test_pjit_keep_unused_true(self):
28262833
@partial(pjit, keep_unused=True)
28272834
def f(x, y, z, a, b, c): # pylint: disable=unused-argument
@@ -4008,120 +4015,6 @@ def f(*args):
40084015
inps = [arr, *[inp] * 2001]
40094016
f(inps) # doesn't crash
40104017

4011-
def test_in_shardings_kwargs(self):
4012-
mesh = jtu.create_global_mesh((2,), 'x')
4013-
s = NamedSharding(mesh, P('x'))
4014-
s1 = NamedSharding(mesh, P())
4015-
np_inp = np.arange(8)
4016-
arr = jax.device_put(np_inp, s1)
4017-
4018-
kw_in_s = dict(y=s1, x=s)
4019-
@partial(jax.jit, in_shardings=kw_in_s, out_shardings=(s, s1))
4020-
def f(x, y):
4021-
return x * 2, y * 2
4022-
4023-
out1, out2 = f(x=np_inp, y=arr)
4024-
self.assertArraysEqual(out1, np_inp * 2)
4025-
self.assertArraysEqual(out2, np_inp * 2)
4026-
self.assertEqual(out1.sharding, s)
4027-
self.assertEqual(out2.sharding, s1)
4028-
4029-
c_in_s = f.lower(x=np_inp, y=np_inp).compile().input_shardings
4030-
_, kw_shardings = c_in_s
4031-
self.assertDictEqual(kw_shardings, kw_in_s)
4032-
4033-
def test_kwargs_leaf_in_shardings(self):
4034-
s = SingleDeviceSharding(jax.devices()[0])
4035-
np_inp = np.arange(16).reshape(8, 2)
4036-
4037-
@partial(jax.jit, in_shardings=s)
4038-
def f(x, y):
4039-
return x @ y.T
4040-
4041-
out = f(x=np_inp, y=np_inp)
4042-
self.assertArraysEqual(out, np_inp @ np_inp.T)
4043-
4044-
def test_arg_in_shardings_kwarg_error(self):
4045-
s = SingleDeviceSharding(jax.devices()[0])
4046-
np_inp = np.arange(16).reshape(8, 2)
4047-
4048-
@partial(jax.jit, in_shardings=dict(x=s, y=s))
4049-
def f(x, y):
4050-
return x * 2, y * 2
4051-
4052-
with self.assertRaisesRegex(
4053-
ValueError,
4054-
'in_shardings can only be an instance of dict if you have kwargs.'
4055-
' Please pass in_shardings positionally if you are using args.'):
4056-
f(np_inp, np_inp)
4057-
4058-
def test_kwarg_in_shardings_positional_error(self):
4059-
s = SingleDeviceSharding(jax.devices()[0])
4060-
np_inp = np.arange(16).reshape(8, 2)
4061-
4062-
@partial(jax.jit, in_shardings=(s, s))
4063-
def f(x, y):
4064-
return x * 2, y * 2
4065-
4066-
with self.assertRaisesRegex(
4067-
ValueError,
4068-
'If you are using kwargs, in_shardings needs to be passed as a dict'
4069-
' corresponding to the kwargs'):
4070-
f(x=np_inp, y=np_inp)
4071-
4072-
def test_invalid_in_shardings_kwarg(self):
4073-
s = SingleDeviceSharding(jax.devices()[0])
4074-
np_inp = np.arange(16).reshape(8, 2)
4075-
4076-
@partial(jax.jit, in_shardings=dict(z=s))
4077-
def f(x):
4078-
return x * 2
4079-
4080-
with self.assertRaisesRegex(
4081-
ValueError,
4082-
"Jitted function has invalid argnames {'z'} in in_shardings. Function"
4083-
" does not take these args. Valid kwargs are {'x'}"):
4084-
f(x=np_inp)
4085-
4086-
def test_mix_args_kwargs_in_shardings(self):
4087-
s = SingleDeviceSharding(jax.devices()[0])
4088-
np_inp = np.arange(16).reshape(8, 2)
4089-
4090-
@partial(jax.jit, in_shardings=dict(x=s, y=s, z=s))
4091-
def f(x, y, z):
4092-
return x * 2, y * 2, z * 2
4093-
4094-
with self.assertRaisesRegex(
4095-
ValueError,
4096-
'If in_shardings is a dict, the function should be called with kwargs'
4097-
' only.'):
4098-
f(np_inp, y=np_inp, z=np_inp)
4099-
4100-
def test_kwargs_in_shardings_partial(self):
4101-
s = SingleDeviceSharding(jax.devices()[0])
4102-
np_inp = np.arange(16).reshape(8, 2)
4103-
4104-
@partial(jax.jit, in_shardings=dict(x=s))
4105-
def f(x, y, z):
4106-
return x * 2, y * 2, z * 2
4107-
4108-
with self.assertRaisesRegex(
4109-
ValueError, "Pytree of in_shardings and kwargs should be equal"):
4110-
f(x=np_inp, y=np_inp, z=np_inp)
4111-
4112-
def test_args_kwargs_in_shardings_mixture(self):
4113-
s = SingleDeviceSharding(jax.devices()[0])
4114-
np_inp = np.arange(16).reshape(8, 2)
4115-
4116-
@partial(jax.jit, in_shardings=(s, dict(y=s, z=s)))
4117-
def f(x, y, z):
4118-
return x * 2, y * 2, z * 2
4119-
4120-
with self.assertRaisesRegex(
4121-
ValueError,
4122-
'If you are using kwargs, in_shardings needs to be passed as a dict'):
4123-
f(np_inp, y=np_inp, z=np_inp)
4124-
41254018

41264019
class TempSharding(Sharding):
41274020

0 commit comments

Comments
 (0)