Estimate memory requirement from jaxpr #17058
-
Hi, if I have some JAX function […], could its memory requirement be read from the generated jaxpr? Even if this gives an inaccurate estimate, it would be very useful for programmatically tuning batch sizes for calls to […]. Any help is greatly appreciated :)
-
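You can get a sense of the memory use with the `cost_analysis()` method of a lowered and compiled function. For example:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

x = jnp.arange(1000)
# Lower for this input's shape/dtype, compile with XLA, then print XLA's cost estimates.
compiled = jax.jit(f).lower(x).compile()
print(compiled.cost_analysis())
```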
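To pull a single number out of that output, a minimal sketch follows. It assumes that, depending on the JAX version, `cost_analysis()` returns either a dict or a list holding one dict per computation, and that XLA reports a `"bytes accessed"` entry. `bytes_accessed` is a hypothetical helper name. Keep in mind that this figure approximates memory *traffic* during execution, not peak resident memory, so it is only a loose proxy for sizing batches.

```python
import jax
import jax.numpy as jnp


def bytes_accessed(compiled):
    """Hypothetical helper: XLA's "bytes accessed" estimate, or None if absent.

    Handles both return shapes of cost_analysis(): a plain dict, or a
    list/tuple whose first element is the dict for the main computation.
    """
    analysis = compiled.cost_analysis()
    if analysis is None:
        return None
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0] if analysis else {}
    return analysis.get("bytes accessed")


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.arange(1000)
compiled = jax.jit(f).lower(x).compile()
print(bytes_accessed(compiled))
```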
I have to admit, though, I don't have a great sense of how to interpret these outputs; there's some info here that may be helpful: https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available
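For the batch-size use case, compiled executables also expose `memory_analysis()`, which on backends that support it reports buffer sizes such as `argument_size_in_bytes`, `output_size_in_bytes` and `temp_size_in_bytes`. The sketch below assumes those attributes are available (it returns `None` when the backend reports nothing). It also assumes the footprint grows monotonically with batch size. `compiled_footprint` and `largest_batch_under_budget` are hypothetical helpers, and the budget is whatever limit you choose. Because lowering from `jax.ShapeDtypeStruct` needs no real data, each candidate is compiled without allocating inputs.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


def compiled_footprint(fun, *args):
    """Hypothetical helper: rough bytes needed by jit(fun)(*args), or None.

    Sums argument, output and temporary buffer sizes reported by XLA.
    Donated/aliased buffers may be double-counted, so treat this as rough.
    """
    compiled = jax.jit(fun).lower(*args).compile()
    stats = compiled.memory_analysis()
    if stats is None:  # backend does not report memory statistics
        return None
    return (stats.argument_size_in_bytes
            + stats.output_size_in_bytes
            + stats.temp_size_in_bytes)


def largest_batch_under_budget(fun, feature_shape, budget_bytes, candidates):
    """Hypothetical helper: largest candidate batch whose footprint fits.

    Assumes footprint grows with batch size, so it stops at the first
    candidate that exceeds the budget (or that cannot be measured).
    """
    best = None
    for batch in sorted(candidates):
        # Abstract input: shape/dtype only, no device memory allocated.
        spec = jax.ShapeDtypeStruct((batch, *feature_shape), jnp.float32)
        footprint = compiled_footprint(fun, spec)
        if footprint is None or footprint > budget_bytes:
            break
        best = batch
    return best


# Example: pick a batch size for 128-feature rows under a 64 MiB budget.
print(largest_batch_under_budget(f, (128,), 64 * 2**20, [32, 64, 128, 256, 512]))
```

Each candidate triggers a separate compilation, so a short candidate list (or a binary search over it) keeps the sweep cheap.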
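Reading an estimate directly from the jaxpr, as the question asks, is also possible in a crude way. This is a minimal sketch, not an established JAX utility: the hypothetical `naive_jaxpr_bytes` adds up the byte size of every input and every value produced by a top-level equation of `jax.make_jaxpr(fun)(*args)`. It assumes no buffer reuse and does not descend into nested jaxprs (e.g. `scan` or `cond` bodies). Because it also ignores XLA's fusion and scratch allocations, it is a pessimistic heuristic rather than a guarantee of actual device usage.

```python
import math

import jax
import jax.numpy as jnp


def naive_jaxpr_bytes(fun, *args):
    """Hypothetical helper: sum of sizes of all jaxpr inputs and intermediates.

    Every value is counted as if it stayed alive for the whole computation,
    and nested jaxprs are not inspected.
    """
    jaxpr = jax.make_jaxpr(fun)(*args).jaxpr

    def nbytes(var):
        aval = var.aval  # abstract value carrying shape and dtype
        return math.prod(aval.shape) * aval.dtype.itemsize

    total = sum(nbytes(v) for v in jaxpr.invars)
    total += sum(nbytes(v) for v in jaxpr.constvars)
    for eqn in jaxpr.eqns:
        total += sum(nbytes(v) for v in eqn.outvars)
    return total


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.arange(1000, dtype=jnp.float32)
print(naive_jaxpr_bytes(f, x))
```

Comparing this number against the compiled estimates is one way to see how much XLA's buffer reuse and fusion save for a given function.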