Description
I was using the `jax` backend with `jax.jit` to optimize the parameters of a `quimb` tensor network with some simulated noise. I noticed that the outputs were not random, and I believe this is because `jit` traces out the global key in autoray.
Minimal example:
```python
import jax
import autoray as ar

@jax.jit
def sampler():
    return ar.do("random.normal", like='jax')
    # return jax.random.normal(key=ar.autoray.jax_random_get_key())

for _ in range(10):
    print(sampler())
```
```
0.11975035
0.11975035
0.11975035
0.11975035
0.11975035
0.11975035
0.11975035
0.11975035
0.11975035
0.11975035
```
Any ideas for a clean workaround that keeps the code backend-agnostic? I guess this fundamentally breaks the assumption, shared by all the other backends, of a changing global RNG state, but I would like to keep `jit`, as it massively increases the code's performance.
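One common JAX pattern that keeps `jit` is to pass the key in explicitly and split it outside the compiled function. The sketch below uses plain JAX only; `noisy_sampler` and its `scale` parameter are hypothetical names for illustration, and this is not an autoray API:

```python
import jax
import jax.numpy as jnp

@jax.jit
def noisy_sampler(key, scale):
    # The key is an explicit argument, so it is a traced input rather than a
    # constant captured at trace time; each call can receive a fresh key.
    return scale * jax.random.normal(key, shape=(3,))

key = jax.random.PRNGKey(42)
samples = []
for _ in range(5):
    # Split outside jit: carry one half forward, consume the other.
    key, subkey = jax.random.split(key)
    samples.append(noisy_sampler(subkey, 0.1))

print(jnp.stack(samples))  # five different rows
```

Since the key is a traced argument, the function is compiled once but produces fresh samples on every call. Keeping the code backend-agnostic would presumably need a thin wrapper that threads a key through for `jax` and ignores it for backends with a mutable global RNG; such a wrapper is an assumption here, not something autoray is known to provide.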