Skip to content

Commit 332450b

Browse files
committed
[key reuse] add internal function_type_signature utility
1 parent e5a16a0 commit 332450b

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

jax/experimental/key_reuse/_core.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def is_key(var: core.Atom):
256256
)
257257

258258
@weakref_lru_cache
259-
def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
259+
def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
260260
"""Parse the jaxpr to determine key reuse signature"""
261261
consumed: dict[core.Atom, bool | np.ndarray] = {}
262262
forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs.
@@ -340,18 +340,22 @@ def is_consumed(var: core.Atom):
340340
)
341341

342342

343+
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
344+
args_flat, in_tree = tree_util.tree_flatten(args)
345+
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
346+
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
347+
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
348+
return jaxpr_type_signature(jaxpr)
349+
350+
343351
def check_key_reuse_jaxpr(jaxpr: core.Jaxpr) -> None:
344352
"""Check the jaxpr for key reuse."""
345-
get_jaxpr_type_signature(jaxpr)
353+
jaxpr_type_signature(jaxpr)
346354

347355

348356
def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
349357
"""Function to statically check key reuse."""
350-
args_flat, in_tree = tree_util.tree_flatten(args)
351-
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
352-
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
353-
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
354-
check_key_reuse_jaxpr(jaxpr)
358+
function_type_signature(fun, *args)
355359

356360

357361
#----------------------------------------------------------------------------------
@@ -384,17 +388,17 @@ def _concatenate_signature(eqn):
384388
key_reuse_signatures_dynamic[lax.concatenate_p] = _concatenate_signature
385389

386390
def _pjit_key_type_signature(eqn):
387-
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
391+
return jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
388392

389393
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
390394

391395
def _shard_map_type_signature(eqn):
392-
return get_jaxpr_type_signature(eqn.params['jaxpr'])
396+
return jaxpr_type_signature(eqn.params['jaxpr'])
393397

394398
key_reuse_signatures_dynamic[shard_map_p] = _shard_map_type_signature
395399

396400
def _cond_key_type_signature(eqn):
397-
signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
401+
signatures = [jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
398402
sinks = defaultdict(list)
399403
sources = defaultdict(list)
400404
for sig in signatures:
@@ -415,7 +419,7 @@ def _scan_key_type_signature(eqn):
415419
jaxpr = eqn.params['jaxpr'].jaxpr
416420
num_consts = eqn.params['num_consts']
417421
num_carry = eqn.params['num_carry']
418-
signature = get_jaxpr_type_signature(jaxpr)
422+
signature = jaxpr_type_signature(jaxpr)
419423

420424
# scan body should not consume key in constants
421425
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
@@ -446,8 +450,8 @@ def _while_key_type_signature(eqn):
446450
body_jaxpr = eqn.params['body_jaxpr'].jaxpr
447451
body_nconsts = eqn.params['body_nconsts']
448452

449-
cond_signature = get_jaxpr_type_signature(cond_jaxpr)
450-
body_signature = get_jaxpr_type_signature(body_jaxpr)
453+
cond_signature = jaxpr_type_signature(cond_jaxpr)
454+
body_signature = jaxpr_type_signature(body_jaxpr)
451455

452456
# Error if there are sinks among consts.
453457
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
@@ -489,7 +493,7 @@ def _remat_key_type_signature(eqn):
489493
# Therefore, the differentiated pass is a no-op.
490494
if eqn.params['differentiated']:
491495
return KeyReuseSignature()
492-
return get_jaxpr_type_signature(eqn.params['jaxpr'])
496+
return jaxpr_type_signature(eqn.params['jaxpr'])
493497

494498
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
495499

@@ -503,15 +507,15 @@ def key_reuse_impl(*args, **kwargs):
503507
if prim == pjit.pjit_p:
504508
funcname = "jit-compiled function"
505509
jaxpr = kwargs['jaxpr'].jaxpr
506-
signature = get_jaxpr_type_signature(jaxpr)
510+
signature = jaxpr_type_signature(jaxpr)
507511
elif prim in key_reuse_signatures:
508512
funcname = str(prim)
509513
jaxpr = None
510514
signature = key_reuse_signatures[prim]
511515
elif prim in key_reuse_signatures_dynamic:
512516
funcname = str(prim)
513517
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
514-
signature = get_jaxpr_type_signature(jaxpr)
518+
signature = jaxpr_type_signature(jaxpr)
515519
else:
516520
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
517521
signature.check_signature(*args, funcname=funcname)

tests/key_reuse_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,13 @@ def test_jaxpr_type_signature(self, primitive):
344344
func, *args = primitives_with_static_signatures[primitive]
345345
signature = _core.key_reuse_signatures[primitive]
346346
jaxpr = jax.make_jaxpr(func)(*args)
347-
self.assertEqual(signature, _core.get_jaxpr_type_signature(jaxpr.jaxpr))
347+
self.assertEqual(signature, _core.jaxpr_type_signature(jaxpr.jaxpr))
348+
349+
@parameterized.parameters(*primitives_with_static_signatures)
350+
def test_function_type_signature(self, primitive):
351+
func, *args = primitives_with_static_signatures[primitive]
352+
signature = _core.key_reuse_signatures[primitive]
353+
self.assertEqual(signature, _core.function_type_signature(func, *args))
348354

349355

350356
@jtu.with_config(jax_enable_key_reuse_checks=False)

0 commit comments

Comments
 (0)