Skip to content

Commit c7bc95d

Browse files
author
jax authors
committed
Merge pull request #20147 from jakevdp:key-reuse-fix-clone
PiperOrigin-RevId: 614059823
2 parents 5e039f7 + d1e49f9 commit c7bc95d

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

jax/_src/random.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from jax._src import config
3232
from jax._src import core
33+
from jax._src import dispatch
3334
from jax._src import dtypes
3435
from jax._src import prng
3536
from jax._src import xla_bridge
@@ -2615,7 +2616,7 @@ def binomial(
26152616

26162617
# Functions related to key reuse checking
26172618
random_clone_p = core.Primitive("random_clone")
2618-
random_clone_p.def_impl(lambda x: x)
2619+
dispatch.simple_impl(random_clone_p)
26192620
random_clone_p.def_abstract_eval(lambda x: x)
26202621
batching.defvectorized(random_clone_p)
26212622
mlir.register_lowering(random_clone_p, lambda _, k: [k])

tests/key_reuse_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
595595
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
596596
traced_bits_msg = "In random_bits, argument 0 is already consumed."
597597

598+
def test_clone_eager(self):
599+
key = jax.random.key(0)
600+
key2 = jax.random.clone(key)
601+
self.assertIsNot(key, key2)
602+
603+
_ = jax.random.uniform(key)
604+
self.assertTrue(key._consumed)
605+
self.assertFalse(key2._consumed)
606+
598607
def test_simple_reuse_nojit(self):
599608
key = jax.random.key(0)
600609
_ = jax.random.bits(key)

0 commit comments

Comments
 (0)