Skip to content

Commit d8f231a

Browse files
author
jax authors
committed
Merge pull request #20250 from jakevdp:key-reuse-jit
PiperOrigin-RevId: 616971171
2 parents d6588ac + ae4e273 commit d8f231a

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

jax/_src/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def trace_context():
213213
softmax_custom_jvp.value,
214214
enable_memories.value,
215215
disable_jit.value,
216+
enable_key_reuse_checks.value,
216217
jax_xla_profile_version.value,
217218
# Technically this affects jaxpr->stablehlo lowering, not tracing.
218219
hlo_source_file_canonicalization_regex.value)

tests/key_reuse_test.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -608,15 +608,13 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
608608
traced_bits_msg = "In random_bits, argument 0 is already consumed."
609609

610610
def test_clone_eager(self):
611-
# TODO(b/329326258): run this test under JIT
612-
with jax.disable_jit():
613-
key = jax.random.key(0)
614-
key2 = jax.random.clone(key)
615-
self.assertIsNot(key, key2)
611+
key = jax.random.key(0)
612+
key2 = jax.random.clone(key)
613+
self.assertIsNot(key, key2)
616614

617-
_ = jax.random.uniform(key)
618-
self.assertTrue(key._consumed)
619-
self.assertFalse(key2._consumed)
615+
_ = jax.random.uniform(key)
616+
self.assertTrue(key._consumed)
617+
self.assertFalse(key2._consumed)
620618

621619
def test_simple_reuse_nojit(self):
622620
key = jax.random.key(0)

0 commit comments

Comments
 (0)