Skip to content

Commit d6f074b

Browse files
hawkinspjax authors
authored andcommitted
Improve documentation and types for api_util.resolve_argnums.
Prefix some private helpers with a _. No functional changes intended. PiperOrigin-RevId: 617627335
1 parent 3f13308 commit d6f074b

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

jax/_src/api_util.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __eq__(self, other):
151151
inspect.Parameter.POSITIONAL_OR_KEYWORD
152152
)
153153

154-
def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
154+
def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
155155
"""
156156
Validate that the argnums are sensible for a given function.
157157
@@ -181,11 +181,14 @@ def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_n
181181
inspect.Parameter.VAR_POSITIONAL
182182
)
183183

184+
184185
_KEYWORD_ARGUMENTS = (
185186
inspect.Parameter.POSITIONAL_OR_KEYWORD,
186187
inspect.Parameter.KEYWORD_ONLY,
187188
)
188-
def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str) -> None:
189+
def _validate_argnames(
190+
sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str
191+
) -> None:
189192
"""
190193
Validate that the argnames are sensible for a given function.
191194
@@ -206,7 +209,6 @@ def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argname
206209
elif param.kind in _INVALID_KEYWORD_ARGUMENTS:
207210
invalid_kwargs.add(param_name)
208211

209-
210212
# Check whether any kwargs are invalid due to position only
211213
invalid_argnames = invalid_kwargs & set(argnames)
212214
if invalid_argnames:
@@ -234,7 +236,6 @@ def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argname
234236
"at the earliest.", SyntaxWarning)
235237

236238

237-
238239
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
239240
dyn_argnums = _ensure_index_tuple(dyn_argnums)
240241
dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums)
@@ -506,8 +507,21 @@ def infer_argnums_and_argnames(
506507

507508

508509
def resolve_argnums(
509-
fun, donate_argnums, donate_argnames, static_argnums, static_argnames
510+
fun: Callable,
511+
donate_argnums: int | Sequence[int] | None,
512+
donate_argnames: str | Iterable[str] | None,
513+
static_argnums: int | Sequence[int] | None,
514+
static_argnames: str | Iterable[str] | None,
510515
) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]:
516+
"""Validates and completes the argnum/argname specification for a jit.
517+
518+
* fills in any missing pieces (e.g., names given numbers, or vice versa),
519+
* validates the argument names/numbers against the function signature,
520+
* validates that donated and static arguments don't intersect.
521+
* rebases the donated arguments so they index into the dynamic arguments,
522+
(after static arguments have been removed), in the order that parameters
523+
are passed into the compiled function.
524+
"""
511525
try:
512526
sig = inspect.signature(fun)
513527
except ValueError as e:
@@ -535,18 +549,18 @@ def resolve_argnums(
535549
sig, donate_argnums, donate_argnames)
536550

537551
# Validation
538-
validate_argnums(sig, static_argnums, "static_argnums")
539-
validate_argnames(sig, static_argnames, "static_argnames")
540-
validate_argnums(sig, donate_argnums, "donate_argnums")
541-
validate_argnames(sig, donate_argnames, "donate_argnames")
552+
_validate_argnums(sig, static_argnums, "static_argnums")
553+
_validate_argnames(sig, static_argnames, "static_argnames")
554+
_validate_argnums(sig, donate_argnums, "donate_argnums")
555+
_validate_argnames(sig, donate_argnames, "donate_argnames")
542556

543557
# Compensate for static argnums absorbing args
544-
assert_no_intersection(static_argnames, donate_argnames)
558+
_assert_no_intersection(static_argnames, donate_argnames)
545559
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
546560
return donate_argnums, donate_argnames, static_argnums, static_argnames
547561

548562

549-
def assert_no_intersection(static_argnames, donate_argnames):
563+
def _assert_no_intersection(static_argnames, donate_argnames):
550564
out = set(static_argnames).intersection(set(donate_argnames))
551565
if out:
552566
raise ValueError(

0 commit comments

Comments
 (0)