File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change 30
30
31
31
from jax ._src import config
32
32
from jax ._src import core
33
+ from jax ._src import dispatch
33
34
from jax ._src import dtypes
34
35
from jax ._src import prng
35
36
from jax ._src import xla_bridge
@@ -2615,7 +2616,7 @@ def binomial(
2615
2616
2616
2617
# Functions related to key reuse checking
2617
2618
random_clone_p = core .Primitive ("random_clone" )
2618
- random_clone_p . def_impl ( lambda x : x )
2619
+ dispatch . simple_impl ( random_clone_p )
2619
2620
random_clone_p .def_abstract_eval (lambda x : x )
2620
2621
batching .defvectorized (random_clone_p )
2621
2622
mlir .register_lowering (random_clone_p , lambda _ , k : [k ])
Original file line number Diff line number Diff line change @@ -595,6 +595,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
595
595
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
596
596
traced_bits_msg = "In random_bits, argument 0 is already consumed."
597
597
598
+ def test_clone_eager (self ):
599
+ key = jax .random .key (0 )
600
+ key2 = jax .random .clone (key )
601
+ self .assertIsNot (key , key2 )
602
+
603
+ _ = jax .random .uniform (key )
604
+ self .assertTrue (key ._consumed )
605
+ self .assertFalse (key2 ._consumed )
606
+
598
607
def test_simple_reuse_nojit (self ):
599
608
key = jax .random .key (0 )
600
609
_ = jax .random .bits (key )
You can’t perform that action at this time.
0 commit comments