Skip to content

Commit 95892fd

Browse files
dougalmhawkinsp
authored andcommitted
Use private names for args in api_util to avoid shadowing kwargs keys.
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util. We probably shouldn't use linear_util...
1 parent 65b6088 commit 95892fd

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

jax/_src/api_util.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
283283
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
284284

285285
@lu.transformation2
286-
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
286+
def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
287287
sentinel = object()
288-
args = [sentinel] * (len(fixed_args) + len(dyn_args))
289-
for i, arg in zip(dyn_argnums, dyn_args):
288+
args = [sentinel] * (len(_fixed_args) + len(dyn_args))
289+
for i, arg in zip(_dyn_argnums, dyn_args):
290290
args[i] = arg
291-
fixed_args_ = iter(fixed_args)
291+
fixed_args_ = iter(_fixed_args)
292292
args = [next(fixed_args_).val if x is sentinel else x for x in args]
293293
assert next(fixed_args_, sentinel) is sentinel
294-
return f(*args, **kwargs)
294+
return _fun(*args, **kwargs)
295295

296296
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
297297
kwargs: dict[str, Any]):
@@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
315315
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
316316

317317
@lu.transformation2
318-
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
319-
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
320-
return f(*args, **kwargs)
318+
def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
319+
kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs)
320+
return _fun(*args, **kwargs)
321321

322322

323323
@lru_cache(maxsize=4096)
@@ -438,9 +438,9 @@ def flat_out_axes(
438438
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
439439

440440
@lu.transformation_with_aux2
441-
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
442-
ans = f(*args, **kwargs)
443-
spec = tree_unflatten(treedef, leaves)
441+
def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs):
442+
ans = _fun(*args, **kwargs)
443+
spec = tree_unflatten(_treedef, _leaves)
444444
try:
445445
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
446446
except ValueError:
@@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
451451
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
452452
"pmapped function's output.")
453453
raise ValueError(msg) from None
454-
store.store(spec_flat)
454+
_store.store(spec_flat)
455455
return ans
456456

457457
def check_callable(fun):
@@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
687687
for path, l in generate_key_paths(x) if l is not static)
688688

689689
@lu.transformation_with_aux2
690-
def result_paths(f, store, *args, **kwargs):
690+
def result_paths(_fun, _store, *args, **kwargs):
691691
"linear_util transform to get output pytree paths of pre-flattened function."
692-
ans = f(*args, **kwargs)
693-
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
692+
ans = _fun(*args, **kwargs)
693+
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
694694
return ans
695695

696696
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,

0 commit comments

Comments
 (0)