File tree Expand file tree Collapse file tree 2 files changed +7
-8
lines changed Expand file tree Collapse file tree 2 files changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -213,6 +213,7 @@ def trace_context():
213
213
softmax_custom_jvp .value ,
214
214
enable_memories .value ,
215
215
disable_jit .value ,
216
+ enable_key_reuse_checks .value ,
216
217
jax_xla_profile_version .value ,
217
218
# Technically this affects jaxpr->stablehlo lowering, not tracing.
218
219
hlo_source_file_canonicalization_regex .value )
Original file line number Diff line number Diff line change @@ -608,15 +608,13 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
608
608
traced_bits_msg = "In random_bits, argument 0 is already consumed."
609
609
610
610
def test_clone_eager (self ):
611
- # TODO(b/329326258): run this test under JIT
612
- with jax .disable_jit ():
613
- key = jax .random .key (0 )
614
- key2 = jax .random .clone (key )
615
- self .assertIsNot (key , key2 )
611
+ key = jax .random .key (0 )
612
+ key2 = jax .random .clone (key )
613
+ self .assertIsNot (key , key2 )
616
614
617
- _ = jax .random .uniform (key )
618
- self .assertTrue (key ._consumed )
619
- self .assertFalse (key2 ._consumed )
615
+ _ = jax .random .uniform (key )
616
+ self .assertTrue (key ._consumed )
617
+ self .assertFalse (key2 ._consumed )
620
618
621
619
def test_simple_reuse_nojit (self ):
622
620
key = jax .random .key (0 )
You can’t perform that action at this time.
0 commit comments