JAX 0.4.1 requires much more memory than JAX 0.2.17 for this code #14537
Hi, I answered this question here. The answer is copied below for posterity:

I suspect that older JAX versions may have had an additional optimization that was removed in later versions. Given the form of your computation, the only candidate I can think of is a factorization of the squared difference that avoids instantiating the full array of pairwise squared differences before the sum. I vaguely recall this was previously an XLA optimization, but it was removed because it is numerically unstable. You can do such an optimization manually if you wish. Here's an approach that seems to work, and the largest intermediate array it generates for the original inputs is of shape …:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def jit_bar2(Y):
    u, v = jnp.triu_indices(Y.shape[0], 1)  # all index pairs (u, v) with u < v
    Y = Y.reshape(Y.shape[0], -1)           # flatten each row: (N, ...) -> (N, D)
    Y2m = (Y ** 2).mean(-1)                 # per-row mean of squares, shape (N,)
    YYTm = (Y @ Y.T) / Y.shape[1]           # Gram matrix divided by D, shape (N, N)
    # mean((Y[u] - Y[v])**2) == Y2m[u] + Y2m[v] - 2 * YYTm[u, v]
    return jnp.sqrt(3 * (Y2m[u] + Y2m[v] - 2 * YYTm[u, v]))

T = np.random.rand(50, 6, 3)  # test with a smaller input
# jit_bar is the original function from the question (not reproduced here).
np.testing.assert_allclose(jit_bar(T), jit_bar2(T), atol=1E-5)
```
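The question's original `jit_bar` is not shown on this page, so the sketch below uses a hypothetical `pairwise_naive`, assuming the target quantity is `sqrt(3 * mean((Y[u] - Y[v])**2))` over flattened rows, which is what `jit_bar2` computes. It contrasts that direct formula with the factorized one, based on the identity mean((a - b)^2) = mean(a^2) + mean(b^2) - 2 * mean(a * b). The helper `intermediate_sizes` (also hypothetical) counts the elements of the largest array each formula builds before any compiler fusion. The `jnp.maximum(sq, 0.0)` clip is an extra safeguard that is not part of the answer: because the factorized form subtracts nearly equal quantities, rounding can push near-zero distances slightly below zero, and `sqrt` would then return NaN.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def pairwise_naive(Y):
    # Hypothetical reference: the direct formula, which builds every pairwise difference.
    u, v = jnp.triu_indices(Y.shape[0], 1)
    Y = Y.reshape(Y.shape[0], -1)            # (N, D)
    diff = Y[u] - Y[v]                       # (P, D) with P = N * (N - 1) // 2
    return jnp.sqrt(3 * (diff ** 2).mean(-1))


@jax.jit
def pairwise_factorized(Y):
    # Same quantity via mean((a - b)^2) = mean(a^2) + mean(b^2) - 2 * mean(a * b).
    u, v = jnp.triu_indices(Y.shape[0], 1)
    Y = Y.reshape(Y.shape[0], -1)            # (N, D)
    Y2m = (Y ** 2).mean(-1)                  # (N,)
    YYTm = (Y @ Y.T) / Y.shape[1]            # (N, N)
    sq = Y2m[u] + Y2m[v] - 2 * YYTm[u, v]    # (P,)
    # Assumed safeguard: cancellation can make sq slightly negative, so clip before sqrt.
    return jnp.sqrt(3 * jnp.maximum(sq, 0.0))


def intermediate_sizes(N, D):
    """Element counts of the largest array each formula builds, ignoring fusion."""
    P = N * (N - 1) // 2
    return {"naive": P * D, "factorized": max(N * N, P)}


T = np.random.rand(50, 6, 3).astype(np.float32)
np.testing.assert_allclose(pairwise_naive(T), pairwise_factorized(T), atol=1e-5)
print(intermediate_sizes(50, 18))  # {'naive': 22050, 'factorized': 2500}
```

For N rows of D features, the direct formula's largest arrays grow like N(N-1)/2 × D elements, while the factorized one is bounded by the N × N Gram matrix, so the savings grow with D. Whether XLA actually materializes the (P, D) arrays in the direct version depends on how it fuses the gather, subtraction, and reduction.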
Hello,
Could someone explain why this code takes much more memory (50 GB+) on jax/jaxlib 0.4.1 than on jax 0.2.17/jaxlib 0.1.68 (~2 GB)?
Thank you