Force a dynamic variable to be static #18347

Answered by jakevdp
chrisflesher asked this question in Q&A
The way to force a variable to be static within a jit-compiled function is to mark it with static_argnums (or static_argnames), as described in the jax.jit docs.
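As a minimal sketch (the function names `tile_sum` and `_tile` are hypothetical, not from the question), here `n` determines an output shape, so it has to be a concrete Python int at trace time. Marking it static, either by name with `static_argnames` or by position with `static_argnums`, lets JAX treat it as a compile-time constant:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: `n` sets the output shape, so it must be static.
@partial(jax.jit, static_argnames="n")
def tile_sum(x, n):
    # jnp.ones needs a concrete shape; this works only because `n` is static.
    return jnp.ones(n) * x


result = tile_sum(2.0, n=3)
print(result)  # [2. 2. 2.]


# The same thing, marking the argument static by position instead of by name.
def _tile(x, n):
    return jnp.ones(n) * x


tile_sum_pos = jax.jit(_tile, static_argnums=1)
positional_result = tile_sum_pos(2.0, 3)
```

Static arguments must be hashable, and every new value of `n` triggers a fresh compilation, so this suits values that change rarely.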

Unfortunately, in your particular case, because you are wrapping the function in vmap, there is no way to mark the input as static: a vmapped value is always dynamic, by definition.
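To see the failure mode, here is a minimal sketch with a hypothetical per-example function `scaled_window_sum(x, n)` (not the question's code) in which `n` sets an array shape. Mapping over `n` with `jax.vmap` turns it into a tracer, so the shape cannot be built and tracing fails:

```python
import jax
import jax.numpy as jnp


def scaled_window_sum(x, n):
    # Hypothetical per-example function: `n` sets an array shape,
    # so it must be a concrete Python int when the function is traced.
    return jnp.sum(jnp.ones(n) * x)


xs = jnp.arange(3.0)
ns = jnp.array([4, 4, 4])

failed = False
try:
    # vmap maps over both arguments, so inside the function `n` is a
    # batched tracer rather than an int, and jnp.ones(n) cannot build a shape.
    jax.vmap(scaled_window_sum)(xs, ns)
except TypeError as err:
    # JAX reports a non-concrete shape as a TypeError (or a subclass of it).
    failed = True
    print(type(err).__name__)
```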

To fix this you'll need to do two things:

  1. adjust your vmap call so you are not mapping over the value you want to be static
  2. adjust your jit call so that the particular value is marked as static via static_argnames. One way to do this is to jit-compile a sub-computation in which you pass the value directly to the function, rather than through vmap.

Doing these together might look like this:

from functools 
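A minimal, self-contained sketch of the two steps, assuming a hypothetical per-example function `scaled_window_sum(x, n)` where `n` must be static: the outer function is jitted with `n` in `static_argnames`, and `n` is bound with `functools.partial` so that `jax.vmap` only maps over `xs` and never sees `n` as an input:

```python
from functools import partial

import jax
import jax.numpy as jnp


def scaled_window_sum(x, n):
    # Hypothetical per-example function: `n` sets a shape, so it must be static.
    return jnp.sum(jnp.ones(n) * x)


@partial(jax.jit, static_argnames="n")
def batched(xs, n):
    # Step 2: `n` is static in the jitted function, so it stays a Python int.
    # Step 1: bind `n` before vmapping, so vmap maps over `xs` only.
    return jax.vmap(partial(scaled_window_sum, n=n))(xs)


xs = jnp.arange(3.0)
out = batched(xs, n=4)
print(out)  # [0. 4. 8.]
```

Binding `n` through a closure or `partial` is one choice; the key point is that the static value reaches the function as a plain Python value instead of as a mapped (and therefore traced) input.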

Answer selected by chrisflesher