Force a dynamic variable to be static #18347
Is there a way to tell the JAX compiler, "I know it's not supposed to be static but just make it static anyways!" Below is a simplified example of the issue:
Running this causes a
However, it can be cumbersome to add extra arguments like this to functions buried several layers deep when I know
The way to force a variable to be static within a `jit`-compiled function is to use `static_argnums`, as mentioned in the `jax.jit` docs. Unfortunately, in your particular case, because you are wrapping the function in `vmap`, there is no way to mark the input as static: a vmapped value is always dynamic, by definition.

To fix this you'll need to do two things:

1. Adjust the `vmap` call so you are not mapping over the value you want to be static.
2. Adjust the `jit` call so that you mark the particular value as static via `static_argnames`. One way to do this is to jit-compile a sub-computation to which you pass the values directly.

Doing these together might look like this:

```python
from functools import partial
import typing
import jax
import jax.numpy as jnp

class Image(typing.NamedTuple):
    resolution: float  # a hashable Python float, so it can be passed as a static argument
    values: jax.Array

def sample_image(image: Image) -> jax.Array:
    # Pass resolution as its own argument so that jit can mark it static.
    return _sample_image(image.resolution, image.values)

@partial(jax.jit, static_argnames=["resolution"])
def _sample_image(resolution, values):
    num_samples = int(1 / resolution)  # OK: resolution is a concrete Python float here
    return values * num_samples

num_images = 3
images = Image(0.2, jnp.ones((num_images, 5)))
# Image(None, 0): do not map over resolution; map over axis 0 of values.
sampled_values = jax.vmap(sample_image, in_axes=(Image(None, 0),))(images)
```

If you are using …, another option is to register `Image` as a custom pytree whose flatten function returns `values` as the dynamic child and `resolution` as static auxiliary data:

```python
from jax.tree_util import register_pytree_node

# values is the (dynamic) child; resolution is stored as static auxiliary data.
# Auxiliary data should be hashable, so use tuples rather than lists.
register_pytree_node(
    Image,
    lambda image: ((image.values,), (image.resolution,)),
    lambda aux_data, children: Image(resolution=aux_data[0], values=children[0]),
)

@jax.jit
def sample_image(image: Image) -> jax.Array:
    num_samples = int(1 / image.resolution)
    return image.values * num_samples

sample_image(Image(0.2, jnp.arange(5)))  # now resolution is always treated as static
```