Skip to content

Commit 4963ae1

Browse files
committed
doctest: avoid modifying global flag state
1 parent cc06836 commit 4963ae1

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

jax/experimental/key_reuse/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,27 @@
1616
Experimental Key Reuse Checking
1717
-------------------------------
1818
19-
This module contains **experimental** functionality for detecting re-use of random
20-
keys within JAX programs. It is under active development and the APIs here are likely
21-
to change. The usage below requires JAX version 0.4.26 or newer.
19+
This module contains **experimental** functionality for detecting reuse of random
20+
keys within JAX programs. It is under active development and the APIs here are
21+
likely to change. The usage below requires JAX version 0.4.26 or newer.
2222
23-
Key reuse checking can be enabled using the `jax_enable_key_reuse_checks` configuration::
23+
Key reuse checking can be enabled using the ``jax_enable_key_reuse_checks`` configuration.
24+
This can be set globally using::
25+
26+
>>> jax.config.update('jax_enable_key_reuse_checks', True) # doctest: +SKIP
27+
28+
Or it can be enabled locally with the :func:`jax.enable_key_reuse_checks` context manager.
29+
When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`::
2430
2531
>>> import jax
26-
>>> jax.config.update('jax_enable_key_reuse_checks', True)
27-
>>> key = jax.random.key(0)
28-
>>> jax.random.normal(key)
29-
Array(-0.20584226, dtype=float32)
30-
>>> jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL
32+
>>> with jax.enable_key_reuse_checks(True):
33+
... key = jax.random.key(0)
34+
... val1 = jax.random.normal(key)
35+
... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL
3136
Traceback (most recent call last):
3237
...
3338
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
3439
35-
This flag can also be controlled locally using the :func:`jax.enable_key_reuse_checks`
36-
context manager::
37-
38-
>>> with jax.enable_key_reuse_checks(False):
39-
... print(jax.random.normal(key))
40-
-0.20584226
40+
The key reuse checker is currently experimental, but in the future we will likely
41+
enable it by default.
4142
"""

0 commit comments

Comments
 (0)