@@ -151,7 +151,7 @@ def __eq__(self, other):
151
151
inspect .Parameter .POSITIONAL_OR_KEYWORD
152
152
)
153
153
154
- def validate_argnums (sig : inspect .Signature , argnums : tuple [int , ...], argnums_name : str ) -> None :
154
+ def _validate_argnums (sig : inspect .Signature , argnums : tuple [int , ...], argnums_name : str ) -> None :
155
155
"""
156
156
Validate that the argnums are sensible for a given function.
157
157
@@ -181,11 +181,14 @@ def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_n
181
181
inspect .Parameter .VAR_POSITIONAL
182
182
)
183
183
184
+
184
185
_KEYWORD_ARGUMENTS = (
185
186
inspect .Parameter .POSITIONAL_OR_KEYWORD ,
186
187
inspect .Parameter .KEYWORD_ONLY ,
187
188
)
188
- def validate_argnames (sig : inspect .Signature , argnames : tuple [str , ...], argnames_name : str ) -> None :
189
+ def _validate_argnames (
190
+ sig : inspect .Signature , argnames : tuple [str , ...], argnames_name : str
191
+ ) -> None :
189
192
"""
190
193
Validate that the argnames are sensible for a given function.
191
194
@@ -206,7 +209,6 @@ def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argname
206
209
elif param .kind in _INVALID_KEYWORD_ARGUMENTS :
207
210
invalid_kwargs .add (param_name )
208
211
209
-
210
212
# Check whether any kwargs are invalid due to position only
211
213
invalid_argnames = invalid_kwargs & set (argnames )
212
214
if invalid_argnames :
@@ -234,7 +236,6 @@ def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argname
234
236
"at the earliest." , SyntaxWarning )
235
237
236
238
237
-
238
239
def argnums_partial (f , dyn_argnums , args , require_static_args_hashable = True ):
239
240
dyn_argnums = _ensure_index_tuple (dyn_argnums )
240
241
dyn_argnums = _ensure_inbounds (False , len (args ), dyn_argnums )
@@ -506,8 +507,21 @@ def infer_argnums_and_argnames(
506
507
507
508
508
509
def resolve_argnums (
509
- fun , donate_argnums , donate_argnames , static_argnums , static_argnames
510
+ fun : Callable ,
511
+ donate_argnums : int | Sequence [int ] | None ,
512
+ donate_argnames : str | Iterable [str ] | None ,
513
+ static_argnums : int | Sequence [int ] | None ,
514
+ static_argnames : str | Iterable [str ] | None ,
510
515
) -> tuple [tuple [int , ...], tuple [str , ...], tuple [int , ...], tuple [str , ...]]:
516
+ """Validates and completes the argnum/argname specification for a jit.
517
+
518
+ * fills in any missing pieces (e.g., names given numbers, or vice versa),
519
+ * validates the argument names/numbers against the function signature,
520
+ * validates that donated and static arguments don't intersect.
521
+ * rebases the donated arguments so they index into the dynamic arguments,
522
+ (after static arguments have been removed), in the order that parameters
523
+ are passed into the compiled function.
524
+ """
511
525
try :
512
526
sig = inspect .signature (fun )
513
527
except ValueError as e :
@@ -535,18 +549,18 @@ def resolve_argnums(
535
549
sig , donate_argnums , donate_argnames )
536
550
537
551
# Validation
538
- validate_argnums (sig , static_argnums , "static_argnums" )
539
- validate_argnames (sig , static_argnames , "static_argnames" )
540
- validate_argnums (sig , donate_argnums , "donate_argnums" )
541
- validate_argnames (sig , donate_argnames , "donate_argnames" )
552
+ _validate_argnums (sig , static_argnums , "static_argnums" )
553
+ _validate_argnames (sig , static_argnames , "static_argnames" )
554
+ _validate_argnums (sig , donate_argnums , "donate_argnums" )
555
+ _validate_argnames (sig , donate_argnames , "donate_argnames" )
542
556
543
557
# Compensate for static argnums absorbing args
544
- assert_no_intersection (static_argnames , donate_argnames )
558
+ _assert_no_intersection (static_argnames , donate_argnames )
545
559
donate_argnums = rebase_donate_argnums (donate_argnums , static_argnums )
546
560
return donate_argnums , donate_argnames , static_argnums , static_argnames
547
561
548
562
549
- def assert_no_intersection (static_argnames , donate_argnames ):
563
+ def _assert_no_intersection (static_argnames , donate_argnames ):
550
564
out = set (static_argnames ).intersection (set (donate_argnames ))
551
565
if out :
552
566
raise ValueError (
0 commit comments