|
16 | 16 | Experimental Key Reuse Checking
|
17 | 17 | -------------------------------
|
18 | 18 |
|
19 |
| -This module contains **experimental** functionality for detecting re-use of random |
20 |
| -keys within JAX programs. It is under active development and the APIs here are likely |
21 |
| -to change. The usage below requires JAX version 0.4.26 or newer. |
| 19 | +This module contains **experimental** functionality for detecting reuse of random |
| 20 | +keys within JAX programs. It is under active development and the APIs here are |
| 21 | +likely to change. The usage below requires JAX version 0.4.26 or newer. |
22 | 22 |
|
23 |
| -Key reuse checking can be enabled using the `jax_enable_key_reuse_checks` configuration:: |
| 23 | +Key reuse checking can be enabled using the ``jax_enable_key_reuse_checks`` configuration. |
| 24 | +This can be set globally using:: |
| 25 | +
|
| 26 | + >>> jax.config.update('jax_enable_key_reuse_checks', True) # doctest: +SKIP |
| 27 | +
|
| 28 | +Or it can be enabled locally with the :func:`jax.enable_key_reuse_checks` context manager. |
| 29 | +When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`:: |
24 | 30 |
|
25 | 31 | >>> import jax
|
26 |
| - >>> jax.config.update('jax_enable_key_reuse_checks', True) |
27 |
| - >>> key = jax.random.key(0) |
28 |
| - >>> jax.random.normal(key) |
29 |
| - Array(-0.20584226, dtype=float32) |
30 |
| - >>> jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL |
| 32 | + >>> with jax.enable_key_reuse_checks(True): |
| 33 | + ... key = jax.random.key(0) |
| 34 | + ... val1 = jax.random.normal(key) |
| 35 | + ... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL |
31 | 36 | Traceback (most recent call last):
|
32 | 37 | ...
|
33 | 38 | KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
|
34 | 39 |
|
35 |
| -This flag can also be controlled locally using the :func:`jax.enable_key_reuse_checks` |
36 |
| -context manager:: |
37 |
| -
|
38 |
| - >>> with jax.enable_key_reuse_checks(False): |
39 |
| - ... print(jax.random.normal(key)) |
40 |
| - -0.20584226 |
| 40 | +The key reuse checker is currently experimental, but in the future we will likely |
| 41 | +enable it by default. |
41 | 42 | """
|
0 commit comments