@@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
283
283
return _argnums_partial (f , dyn_argnums , tuple (fixed_args )), dyn_args
284
284
285
285
@lu .transformation2
286
- def _argnums_partial (f , dyn_argnums , fixed_args , * dyn_args , ** kwargs ):
286
+ def _argnums_partial (_fun , _dyn_argnums , _fixed_args , * dyn_args , ** kwargs ):
287
287
sentinel = object ()
288
- args = [sentinel ] * (len (fixed_args ) + len (dyn_args ))
289
- for i , arg in zip (dyn_argnums , dyn_args ):
288
+ args = [sentinel ] * (len (_fixed_args ) + len (dyn_args ))
289
+ for i , arg in zip (_dyn_argnums , dyn_args ):
290
290
args [i ] = arg
291
- fixed_args_ = iter (fixed_args )
291
+ fixed_args_ = iter (_fixed_args )
292
292
args = [next (fixed_args_ ).val if x is sentinel else x for x in args ]
293
293
assert next (fixed_args_ , sentinel ) is sentinel
294
- return f (* args , ** kwargs )
294
+ return _fun (* args , ** kwargs )
295
295
296
296
def argnames_partial_except (f : lu .WrappedFun , static_argnames : tuple [str , ...],
297
297
kwargs : dict [str , Any ]):
@@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
315
315
return _argnames_partial (f , WrapKwArgs (fixed_kwargs )), dyn_kwargs
316
316
317
317
@lu .transformation2
318
- def _argnames_partial (f , fixed_kwargs : WrapKwArgs , * args , ** dyn_kwargs ):
319
- kwargs = dict ({k : v .val for k , v in fixed_kwargs .val .items ()}, ** dyn_kwargs )
320
- return f (* args , ** kwargs )
318
+ def _argnames_partial (_fun , _fixed_kwargs : WrapKwArgs , * args , ** dyn_kwargs ):
319
+ kwargs = dict ({k : v .val for k , v in _fixed_kwargs .val .items ()}, ** dyn_kwargs )
320
+ return _fun (* args , ** kwargs )
321
321
322
322
323
323
@lru_cache (maxsize = 4096 )
@@ -438,9 +438,9 @@ def flat_out_axes(
438
438
return f , HashableFunction (out_axes , closure = (tuple (leaves ), treedef ))
439
439
440
440
@lu .transformation_with_aux2
441
- def _flat_out_axes (f , store , leaves , treedef , * args , ** kwargs ):
442
- ans = f (* args , ** kwargs )
443
- spec = tree_unflatten (treedef , leaves )
441
+ def _flat_out_axes (_fun , _store , _leaves , _treedef , * args , ** kwargs ):
442
+ ans = _fun (* args , ** kwargs )
443
+ spec = tree_unflatten (_treedef , _leaves )
444
444
try :
445
445
spec_flat = tuple (broadcast_prefix (spec , ans , is_leaf = lambda x : x is None ))
446
446
except ValueError :
@@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
451
451
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
452
452
"pmapped function's output." )
453
453
raise ValueError (msg ) from None
454
- store .store (spec_flat )
454
+ _store .store (spec_flat )
455
455
return ans
456
456
457
457
def check_callable (fun ):
@@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
687
687
for path , l in generate_key_paths (x ) if l is not static )
688
688
689
689
@lu .transformation_with_aux2
690
- def result_paths (f , store , * args , ** kwargs ):
690
+ def result_paths (_fun , _store , * args , ** kwargs ):
691
691
"linear_util transform to get output pytree paths of pre-flattened function."
692
- ans = f (* args , ** kwargs )
693
- store .store ([keystr (path ) for path , _ in generate_key_paths (ans )])
692
+ ans = _fun (* args , ** kwargs )
693
+ _store .store ([keystr (path ) for path , _ in generate_key_paths (ans )])
694
694
return ans
695
695
696
696
def jaxpr_debug_info (jaxpr : core .Jaxpr , trace_debug : TracingDebugInfo | None ,
0 commit comments