Skip to content

Commit 8949a63

Browse files
committed
[key reuse] rename flag to jax_debug_key_reuse
1 parent cd79e71 commit 8949a63

18 files changed

+44
-44
lines changed

jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from jax._src.config import (
4848
config as config,
4949
enable_checks as enable_checks,
50-
enable_key_reuse_checks as enable_key_reuse_checks,
50+
debug_key_reuse as debug_key_reuse,
5151
check_tracer_leaks as check_tracer_leaks,
5252
checking_leaks as checking_leaks,
5353
enable_custom_prng as enable_custom_prng,

jax/_src/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def trace_context():
213213
softmax_custom_jvp.value,
214214
enable_memories.value,
215215
disable_jit.value,
216-
enable_key_reuse_checks.value,
216+
debug_key_reuse.value,
217217
jax_xla_profile_version.value,
218218
# Technically this affects jaxpr->stablehlo lowering, not tracing.
219219
hlo_source_file_canonicalization_regex.value)
@@ -930,8 +930,8 @@ def update_thread_local_jit_state(**kw):
930930
default=False,
931931
help='Turn on invariant checking for JAX internals. Makes things slower.')
932932

933-
enable_key_reuse_checks = define_bool_state(
934-
name='jax_enable_key_reuse_checks',
933+
debug_key_reuse = define_bool_state(
934+
name='jax_debug_key_reuse',
935935
default=False,
936936
help=('Turn on experimental key reuse checking. With this configuration enabled,'
937937
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'

jax/_src/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2861,7 +2861,7 @@ def ctx_factory():
28612861
raise JaxprTypeError(msg) from None
28622862

28632863
# Run key reuse checker after validating jaxpr:
2864-
if config.enable_key_reuse_checks.value:
2864+
if config.debug_key_reuse.value:
28652865
# Import here to avoid circular imports
28662866
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error
28672867
check_key_reuse_jaxpr(jaxpr)

jax/_src/errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,12 +661,12 @@ def __init__(self, msg: str):
661661
class KeyReuseError(JAXTypeError):
662662
"""
663663
This error occurs when a PRNG key is reused in an unsafe manner.
664-
Key reuse is checked only when `jax_enable_key_reuse_checks` is
664+
Key reuse is checked only when `jax_debug_key_reuse` is
665665
set to `True`.
666666
667667
Here is a simple example of code that would lead to such an error::
668668
669-
>>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP
669+
>>> with jax.debug_key_reuse(True): # doctest: +SKIP
670670
... key = jax.random.key(0)
671671
... value = jax.random.uniform(key)
672672
... new_value = jax.random.uniform(key)

jax/_src/pjit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _get_fastpath_data(
236236
# no ref state effects
237237
and not any(isinstance(e, RefEffect) for e in effects)
238238
# no prng reuse checking
239-
and not (config.enable_key_reuse_checks.value and any(
239+
and not (config.debug_key_reuse.value and any(
240240
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
241241
for arg in (*args_flat, *out_flat)))
242242
)
@@ -1150,7 +1150,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
11501150
if not config.dynamic_shapes.value and not attrs_tracked:
11511151
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
11521152

1153-
if config.enable_key_reuse_checks.value:
1153+
if config.debug_key_reuse.value:
11541154
# Import here to avoid circular imports
11551155
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
11561156
check_key_reuse_jaxpr(jaxpr)

jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343

4444
@jtu.with_config(jax_legacy_prng_key='allow',
45-
jax_enable_key_reuse_checks=False)
45+
jax_debug_key_reuse=False)
4646
class JaxPrimitiveTest(jtu.JaxTestCase):
4747

4848
# This test runs for all primitive harnesses. For each primitive "xxx" the

jax/experimental/jax2tf/tests/tf_test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
158158
@jtu.with_config(jax_numpy_rank_promotion="allow",
159159
jax_numpy_dtype_promotion='standard',
160160
jax_legacy_prng_key="allow",
161-
jax_enable_key_reuse_checks=False)
161+
jax_debug_key_reuse=False)
162162
class JaxToTfTestCase(jtu.JaxTestCase):
163163
# We want most tests to use the maximum available version, from the locally
164164
# installed tfxla module and export.

jax/experimental/key_reuse/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@
2020
keys within JAX programs. It is under active development and the APIs here are
2121
likely to change. The usage below requires JAX version 0.4.26 or newer.
2222
23-
Key reuse checking can be enabled using the ``jax_enable_key_reuse_checks`` configuration.
23+
Key reuse checking can be enabled using the ``jax_debug_key_reuse`` configuration.
2424
This can be set globally using::
2525
26-
>>> jax.config.update('jax_enable_key_reuse_checks', True) # doctest: +SKIP
26+
>>> jax.config.update('jax_debug_key_reuse', True) # doctest: +SKIP
2727
28-
Or it can be enabled locally with the :func:`jax.enable_key_reuse_checks` context manager.
28+
Or it can be enabled locally with the :func:`jax.debug_key_reuse` context manager.
2929
When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`::
3030
3131
>>> import jax
32-
>>> with jax.enable_key_reuse_checks(True):
32+
>>> with jax.debug_key_reuse(True):
3333
... key = jax.random.key(0)
3434
... val1 = jax.random.normal(key)
3535
... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL

jax/experimental/key_reuse/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def _remat_key_type_signature(eqn):
530530
def key_reuse_impl_rule(prim, original_rule):
531531
@wraps(original_rule)
532532
def key_reuse_impl(*args, **kwargs):
533-
if config.enable_key_reuse_checks.value:
533+
if config.debug_key_reuse.value:
534534
if prim == pjit.pjit_p:
535535
funcname = "jit-compiled function"
536536
jaxpr = kwargs['jaxpr'].jaxpr

tests/batching_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def body_fn(uk):
962962
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
963963
return u
964964

965-
with jax.enable_key_reuse_checks(False):
965+
with jax.debug_key_reuse(False):
966966
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash
967967

968968
def testEmptyTuples(self):

0 commit comments

Comments
 (0)