Estimate memory requirement from jaxpr #17058

Answered by jakevdp
joeryjoery asked this question in Q&A

You can get a sense of memory use from the cost_analysis() method of a lowered and compiled function, which reports, among other things, how many bytes each operand and the output access. For example:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

x = jnp.arange(1000)
# Lower and compile f for this input's shape and dtype, then ask XLA
# for its cost estimates of the compiled computation.
compiled = jax.jit(f).lower(x).compile()
print(compiled.cost_analysis())
```

Output:

```
[{'transcendentals': 2000.0,
  'bytes accessed operand 0 {}': 4000.0,
  'bytes accessed output {}': 4000.0,
  'bytes accessed': 8000.0,
  'utilization operand 1 {}': 3.0,
  'bytes accessed operand 1 {}': 12000.0,
  'flops': 4000.0,
  'utilization operand 0 {}': 1.0}]
```
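Some of these numbers can be checked by hand. Under JAX's default 32-bit mode, jnp.arange(1000) is int32 and sin/cos promote it to float32, so the input and the output are each 1000 × 4 = 4000 bytes. Those match 'bytes accessed operand 0 {}' and 'bytes accessed output {}', and together give the 8000 in 'bytes accessed'. Likewise 'transcendentals': 2000.0 is one sin plus one cos per element. The sketch below recomputes the byte counts from shapes and dtypes alone using jax.eval_shape; the nbytes helper is hypothetical, written only for this illustration.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


def nbytes(a):
    # Hypothetical helper: element count times bytes per element.
    return math.prod(a.shape) * jnp.dtype(a.dtype).itemsize


x = jnp.arange(1000)        # int32 under JAX's default 32-bit mode
out = jax.eval_shape(f, x)  # abstract result: shape and dtype only, nothing computed

input_bytes = nbytes(x)     # 1000 int32 elements * 4 bytes
output_bytes = nbytes(out)  # 1000 float32 elements * 4 bytes
print(out.dtype, input_bytes, output_bytes, input_bytes + output_bytes)
# float32 4000 4000 8000
```

Keep in mind that these are traffic counts, not a peak footprint: XLA may fuse the sin, cos, squares and add into a single kernel, so no intermediate arrays need to be materialized.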

I have to admit, though, I don't have a great sense of how to interpret these outputs; there's some info here that may be h…
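Because the question asks about the jaxpr specifically, here is a different and much cruder approach that works on the jaxpr directly. It is a sketch under loud assumptions, not what XLA allocates. It walks the top-level jaxpr from jax.make_jaxpr and sums the sizes of the inputs and of every equation's outputs, as if each got its own buffer. It ignores fusion, buffer reuse and donation, and it does not look inside nested jaxprs (for example from an inner jit, scan or cond), so it usually overcounts. The helper names aval_bytes and naive_jaxpr_bytes are hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


def aval_bytes(aval):
    # Hypothetical helper: element count times bytes per element.
    return math.prod(aval.shape) * aval.dtype.itemsize


def naive_jaxpr_bytes(fun, *args):
    """Hypothetical helper: sum buffer sizes in the top-level jaxpr of fun."""
    jaxpr = jax.make_jaxpr(fun)(*args).jaxpr
    # Function inputs plus any captured constants.
    inputs = sum(aval_bytes(v.aval) for v in [*jaxpr.invars, *jaxpr.constvars])
    # Outputs of every equation (the final result included), counted as if
    # nothing were fused away or reused; nested jaxprs are not inspected.
    intermediates = sum(
        aval_bytes(v.aval) for eqn in jaxpr.eqns for v in eqn.outvars
    )
    return inputs, intermediates


x = jnp.arange(1000)
print(jax.make_jaxpr(f)(x))  # inspect which equations are being counted
inputs, intermediates = naive_jaxpr_bytes(f, x)
print("inputs:", inputs, "bytes; intermediates:", intermediates, "bytes")
```

For XLA's own post-optimization figures, the compiled object from the cost_analysis() example also has a memory_analysis() method; what it reports, and whether it is available at all, depends on the backend.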

Answer selected by joeryjoery