
Generating an array of random choices on the CPU, not GPU #21974

Answered by dfm
Peter-Vincent asked this question in Q&A

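The reason for that ConcretizationTypeError is that some of the arguments to random.choice must be static. Here these are a (your max_val, whose concrete value random.choice needs in order to build the range it samples from) and shape, which are positional arguments 1 and 2. You should be able to get your proposed jit version to work using something like: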

cpu_func = jax.jit(random.choice, static_argnums=(1, 2), device=device_cpu)
cpu_func(key, max_val, shape)
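Below is a self-contained sketch of the same idea that can be run end to end. The values of max_val, shape and the PRNG seed are illustrative assumptions, not taken from the original question. The sketch first shows the failure when only shape is marked static (max_val is then traced, and random.choice cannot read its concrete value), then the working version with static_argnums=(1, 2). Instead of passing device= to jax.jit (deprecated in newer JAX releases), it places the computation on the CPU by committing the key to the CPU with jax.device_put, so the jitted call runs where its input lives.

```python
import jax
from jax import random

# Every JAX install has a CPU backend, even when a GPU is also present.
device_cpu = jax.devices("cpu")[0]

max_val = 10   # illustrative: sample integers from range(max_val)
shape = (5,)   # illustrative output shape; a tuple, so it is hashable and can be static

# Committing the key to the CPU makes jitted computations that consume it
# run on the CPU ("computation follows data").
key = jax.device_put(random.PRNGKey(0), device_cpu)

# 1) Only `shape` static: `a` (max_val) is traced, but random.choice needs
#    its concrete integer value, so tracing raises ConcretizationTypeError.
try:
    jax.jit(random.choice, static_argnums=(2,))(key, max_val, shape)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

# 2) Both `a` (positional arg 1) and `shape` (positional arg 2) static: this traces.
cpu_choice = jax.jit(random.choice, static_argnums=(1, 2))
samples = cpu_choice(key, max_val, shape)

print(raised)                                # → True
print(samples.shape)                         # → (5,)
print({d.platform for d in samples.devices()})  # → {'cpu'}
```

Because max_val and shape are static, each new combination of them triggers a fresh compilation; the key, which is an ordinary traced argument, can change freely between calls without recompiling.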

Hope this helps!

Answer selected by Peter-Vincent