lax cond allocating memory for inactive branch? #17463
Moving the discussion here; the original discussion was at #3103. In the original thread, it was pointed out that both branches of `lax.cond` are traced, but only abstractly. However, it seems that memory is being allocated during this "abstract" phase, which causes an OOM when one of the branches uses excessive memory. Wrapping the call in `jax.jit` does not help either (see the second call below).

My hypothesis is that XLA preallocates memory for every tensor, so before the predicate is examined, XLA has to allocate memory for both branches. Is it possible to delay the memory allocation?

```python
import jax
import jax.numpy as jnp
from jax import lax

def a(key):
    # (32,)*6 has 32**6 = 2**30 elements, i.e. about 4 GiB as float32
    return jax.random.normal(key, (32,)*6).sum()

def b(key):
    return jnp.array(1.)

lax.cond(True, b, a, jax.random.PRNGKey(0))  # this OOMs
jax.jit(lambda x, pred: lax.cond(pred, b, a, x))(jax.random.PRNGKey(0), True)  # this OOMs too
```
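To separate tracing from execution, here is a minimal sketch (assuming a recent JAX version; `out_struct` and `jaxpr` are illustrative names). `jax.eval_shape` traces branch `a` without allocating its large intermediate, while `jax.make_jaxpr` shows that the traced `cond` contains both branches, including the `f32[32,32,32,32,32,32]` array produced by `jax.random.normal`. Neither call runs the XLA program, so nothing large is allocated here.

```python
import jax
import jax.numpy as jnp
from jax import lax

def a(key):
    # Large branch: materializes a 32**6-element float32 array before summing.
    return jax.random.normal(key, (32,) * 6).sum()

def b(key):
    # Small branch: returns a scalar.
    return jnp.array(1.)

key = jax.random.PRNGKey(0)

# Abstract evaluation only: returns a ShapeDtypeStruct, no 4 GiB buffer is created.
out_struct = jax.eval_shape(a, key)
print(out_struct)

# Trace (but do not compile or run) the cond; pred is passed as an argument
# so it is a traced value, as in the jitted call above.
jaxpr = jax.make_jaxpr(lambda k, pred: lax.cond(pred, b, a, k))(key, True)
print(jaxpr)  # both branches appear, including the large intermediate
```

The printed jaxpr is what gets lowered to a single XLA program, which is why that program has to account for the memory of the large branch even when `pred` selects the small one.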
I'd really appreciate it if anyone with the relevant knowledge could share some insight on this.
So XLA performs all memory allocation up-front, i.e. there is no dynamic memory allocation. In particular, that means it has to allocate enough memory to support either branch of the `cond`.
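When the predicate is known on the host (a plain Python `bool`), one possible workaround is to branch in ordinary Python and jit each branch separately, so each compiled program only reserves memory for its own branch. This is a minimal sketch reusing `a` and `b` from the question; `pick_branch`, `a_jit`, and `b_jit` are hypothetical names. It does not help when `pred` is a traced value inside `jax.jit`; there `lax.cond` is still needed, and the up-front allocation for both branches remains.

```python
import jax
import jax.numpy as jnp

def a(key):
    # Large branch from the question (~4 GiB float32 intermediate).
    return jax.random.normal(key, (32,) * 6).sum()

def b(key):
    # Small branch from the question.
    return jnp.array(1.)

# Each branch becomes its own XLA program, compiled lazily on first call.
a_jit = jax.jit(a)
b_jit = jax.jit(b)

def pick_branch(pred: bool, key):
    # Hypothetical helper: `pred` must be a concrete Python bool. This is
    # ordinary Python control flow (not lax.cond), so only the selected
    # branch is traced, compiled, and executed.
    return b_jit(key) if pred else a_jit(key)

result = pick_branch(True, jax.random.PRNGKey(0))
print(result)
```

The trade-off is that the branch choice happens outside XLA, so this only fits cases where the predicate does not depend on values computed inside a jitted function.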