@@ -256,7 +256,7 @@ def is_key(var: core.Atom):
256
256
)
257
257
258
258
@weakref_lru_cache
259
- def get_jaxpr_type_signature (jaxpr : core .Jaxpr ) -> KeyReuseSignature :
259
+ def jaxpr_type_signature (jaxpr : core .Jaxpr ) -> KeyReuseSignature :
260
260
"""Parse the jaxpr to determine key reuse signature"""
261
261
consumed : dict [core .Atom , bool | np .ndarray ] = {}
262
262
forwards : dict [core .Atom , core .Atom ] = {} # map forwarded outputs to inputs.
@@ -340,18 +340,22 @@ def is_consumed(var: core.Atom):
340
340
)
341
341
342
342
343
+ def function_type_signature (fun : Callable [..., Any ], * args : Any ) -> KeyReuseSignature :
344
+ args_flat , in_tree = tree_util .tree_flatten (args )
345
+ in_avals_flat = [core .get_aval (arg ) for arg in args_flat ]
346
+ wrapped_fun , _ = api_util .flatten_fun_nokwargs (lu .wrap_init (fun ), in_tree )
347
+ jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (wrapped_fun , in_avals_flat )
348
+ return jaxpr_type_signature (jaxpr )
349
+
350
+
343
351
def check_key_reuse_jaxpr (jaxpr : core .Jaxpr ) -> None :
344
352
"""Check the jaxpr for key reuse."""
345
- get_jaxpr_type_signature (jaxpr )
353
+ jaxpr_type_signature (jaxpr )
346
354
347
355
348
356
def check_key_reuse (fun : Callable [..., Any ], / , * args : Any ) -> None :
349
357
"""Function to statically check key reuse."""
350
- args_flat , in_tree = tree_util .tree_flatten (args )
351
- in_avals_flat = [core .get_aval (arg ) for arg in args_flat ]
352
- wrapped_fun , _ = api_util .flatten_fun_nokwargs (lu .wrap_init (fun ), in_tree )
353
- jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (wrapped_fun , in_avals_flat )
354
- check_key_reuse_jaxpr (jaxpr )
358
+ function_type_signature (fun , * args )
355
359
356
360
357
361
#----------------------------------------------------------------------------------
@@ -384,17 +388,17 @@ def _concatenate_signature(eqn):
384
388
key_reuse_signatures_dynamic [lax .concatenate_p ] = _concatenate_signature
385
389
386
390
def _pjit_key_type_signature (eqn ):
387
- return get_jaxpr_type_signature (eqn .params ['jaxpr' ].jaxpr )
391
+ return jaxpr_type_signature (eqn .params ['jaxpr' ].jaxpr )
388
392
389
393
key_reuse_signatures_dynamic [pjit .pjit_p ] = _pjit_key_type_signature
390
394
391
395
def _shard_map_type_signature (eqn ):
392
- return get_jaxpr_type_signature (eqn .params ['jaxpr' ])
396
+ return jaxpr_type_signature (eqn .params ['jaxpr' ])
393
397
394
398
key_reuse_signatures_dynamic [shard_map_p ] = _shard_map_type_signature
395
399
396
400
def _cond_key_type_signature (eqn ):
397
- signatures = [get_jaxpr_type_signature (branch .jaxpr ) for branch in eqn .params ['branches' ]]
401
+ signatures = [jaxpr_type_signature (branch .jaxpr ) for branch in eqn .params ['branches' ]]
398
402
sinks = defaultdict (list )
399
403
sources = defaultdict (list )
400
404
for sig in signatures :
@@ -415,7 +419,7 @@ def _scan_key_type_signature(eqn):
415
419
jaxpr = eqn .params ['jaxpr' ].jaxpr
416
420
num_consts = eqn .params ['num_consts' ]
417
421
num_carry = eqn .params ['num_carry' ]
418
- signature = get_jaxpr_type_signature (jaxpr )
422
+ signature = jaxpr_type_signature (jaxpr )
419
423
420
424
# scan body should not consume key in constants
421
425
if any (np .any (s .mask ) for s in signature .sinks if s .idx < num_consts ):
@@ -446,8 +450,8 @@ def _while_key_type_signature(eqn):
446
450
body_jaxpr = eqn .params ['body_jaxpr' ].jaxpr
447
451
body_nconsts = eqn .params ['body_nconsts' ]
448
452
449
- cond_signature = get_jaxpr_type_signature (cond_jaxpr )
450
- body_signature = get_jaxpr_type_signature (body_jaxpr )
453
+ cond_signature = jaxpr_type_signature (cond_jaxpr )
454
+ body_signature = jaxpr_type_signature (body_jaxpr )
451
455
452
456
# Error if there are sinks among consts.
453
457
if any (np .any (s .mask ) for s in cond_signature .sinks if s .idx < cond_nconsts ):
@@ -489,7 +493,7 @@ def _remat_key_type_signature(eqn):
489
493
# Therefore, the differentiated pass is a no-op.
490
494
if eqn .params ['differentiated' ]:
491
495
return KeyReuseSignature ()
492
- return get_jaxpr_type_signature (eqn .params ['jaxpr' ])
496
+ return jaxpr_type_signature (eqn .params ['jaxpr' ])
493
497
494
498
key_reuse_signatures_dynamic [remat_p ] = _remat_key_type_signature
495
499
@@ -503,15 +507,15 @@ def key_reuse_impl(*args, **kwargs):
503
507
if prim == pjit .pjit_p :
504
508
funcname = "jit-compiled function"
505
509
jaxpr = kwargs ['jaxpr' ].jaxpr
506
- signature = get_jaxpr_type_signature (jaxpr )
510
+ signature = jaxpr_type_signature (jaxpr )
507
511
elif prim in key_reuse_signatures :
508
512
funcname = str (prim )
509
513
jaxpr = None
510
514
signature = key_reuse_signatures [prim ]
511
515
elif prim in key_reuse_signatures_dynamic :
512
516
funcname = str (prim )
513
517
jaxpr = jax .make_jaxpr (partial (prim .bind , ** kwargs ))(* args ).jaxpr
514
- signature = get_jaxpr_type_signature (jaxpr )
518
+ signature = jaxpr_type_signature (jaxpr )
515
519
else :
516
520
raise RuntimeError (f"Internal: no key reuse rule for primitive { prim } " )
517
521
signature .check_signature (* args , funcname = funcname )
0 commit comments