Use "real" python code inside of jit #18130
-
My functions often compute values that do not depend on the contents of any of the inputs, but rather on the shape of the inputs or not on the inputs at all. Is there any way to stage these precomputations out of jit?
This behavior is exactly what we expect from jit.
-
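It sounds like you're looking for `jax.ensure_compile_time_eval`:

```python
from functools import cached_property
import jax
import jax.numpy as jnp
from jax import Array

class Foo:
    @cached_property
    def some_property(self) -> Array:
        with jax.ensure_compile_time_eval():
            return jnp.ones((10, 10))
```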
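Here is a sketch of how this pattern might be used, assuming a hypothetical `Operator` class and `apply` function. The `ensure_compile_time_eval` block matters when the property is first accessed inside a jitted function. Without it, `jnp.eye` would produce a tracer there, and `cached_property` would store that tracer so that it leaks out of the trace.

```python
from functools import cached_property

import jax
import jax.numpy as jnp


class Operator:
    """Hypothetical example class; the name and `n` are illustrative."""

    def __init__(self, n: int):
        self.n = n

    @cached_property
    def matrix(self) -> jax.Array:
        # Evaluate eagerly even if first accessed while tracing, so the
        # cached value is a concrete array rather than a leaked tracer.
        with jax.ensure_compile_time_eval():
            return 2.0 * jnp.eye(self.n)


op = Operator(3)


@jax.jit
def apply(u):
    # op.matrix is computed once, at trace time, and reused afterwards.
    return op.matrix @ u
```

After `apply` has run, `op.matrix` can still be used outside of `jit` as an ordinary array.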
-
To add on to the answer here: compiling will generally precompute values that depend only on array shapes and known constants. However, there are some exceptions to this. One example is that … Another exception is if the number of dimensions in the constants gets too large. Past some point, constant propagation will judge that it's too expensive and defer the computation to runtime.

With regard to the specifics of your problem here, explicitly precomputing the desired matrix in advance is indeed probably what you want.

By the way, if you're solving differential equations, then you might like Diffrax. It mostly targets ODE/SDE solving, but it can handle some semidiscretised PDEs as well.
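One way to precompute the matrix explicitly is sketched below. It assumes a 1D finite-difference operator, and `make_laplacian` and `rhs` are hypothetical names. The idea is to build the matrix once with NumPy outside of `jax.jit` and pass it in as an ordinary argument, so nothing depends on the compiler deciding to constant-fold it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_laplacian(n: int, dx: float) -> np.ndarray:
    """Hypothetical helper: 1D second-difference matrix.

    Boundary rows assume zero values outside the grid.
    """
    main = -2.0 * np.ones(n)
    off = np.ones(n - 1)
    return (np.diag(main) + np.diag(off, 1) + np.diag(off, -1)) / dx**2


@jax.jit
def rhs(u, lap):
    # lap is a runtime input here, so nothing relies on XLA constant folding.
    return lap @ u


# Precompute once, outside of jit.
lap = jnp.asarray(make_laplacian(5, dx=1.0))
```

Passing `lap` as an argument rather than closing over it also means that a different matrix with the same shape and dtype reuses the already-compiled function.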