Question regarding lifetime of JAX arrays #18330
Replies: 2 comments · 6 replies
-
Thanks for the question!
That sounds accurate to me, modulo async dispatch (since the lifetime of arrays must extend until all consumers are finished, independent of the references the Python code holds). So there must be something going on here. I'm not familiar with the HF code. Could it have some caching? Could you share installation instructions for the HF code you're using? Or, even better, try to make a minimal reproducer which e.g. doesn't depend on extra libraries?
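For illustration, a tiny sketch of the async-dispatch point (the shapes and operations here are arbitrary, not taken from this thread):

```python
import jax.numpy as jnp

x = jnp.ones((4096, 4096))
y = x @ x              # dispatched asynchronously; y is returned before the matmul finishes
del x                  # the Python reference is gone, but the buffer backing x must stay
                       # alive on the device until the pending matmul that consumes it is done
y.block_until_ready()  # blocks until the computation has actually completed
```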
-
Hey, first of all, thanks a lot for your help! I did a clean reinstall of JAX and the issue disappeared. This is when I realized that I had added some breakpoints while debugging, and that these are what causes the retained memory. Here is a minimal example that shows the issue:

```python
import gc
import time
import jax
import jax.numpy as jnp
from flax.core import unfreeze
from jax.lib import xla_bridge
import flax.linen as nn


class ModelWrapper:
    def __init__(self, set_breakpoint: bool):
        if set_breakpoint:
            jax.lax.cond(False, jax.debug.breakpoint, lambda: None)
        self.params = unfreeze(
            nn.Dense(456).init({"params": jax.random.PRNGKey(0)}, jnp.zeros((123,)))["params"])

    def __del__(self):
        print("Deleting")


def count_copies():
    # Count the number of occurrences of the array with shape (123, 456), which appears exactly once
    # in the parameters, to understand how many copies of the parameters currently exist
    return len([None for e in xla_bridge.get_backend().live_arrays() if e.shape == (123, 456)])


def mk_model(set_breakpoint: bool):
    print(f"Set breakpoint: {set_breakpoint}")
    print("mk_model_0", count_copies())
    ModelWrapper(set_breakpoint)
    print("mk_model_1", count_copies())


print("pre_0", count_copies())
mk_model(False)
print("post_0", count_copies())
print("")

print("pre_1", count_copies())
mk_model(True)
print("post_1", count_copies())
print("")

while True:
    gc.collect()
    print("gc", count_copies())
    time.sleep(1.0)
```

which outputs:
From the logs it becomes apparent that the ModelWrapper instance is no longer destroyed once a breakpoint is set in the constructor (even though it can never be hit). I assume that the breakpoint stores a reference to the ModelWrapper instance as part of its trace. This behavior is a bit annoying for me, as it changes the memory usage of my code whenever I set breakpoints and thus makes debugging memory issues with breakpoints harder. However, I think I can work around it by manually deleting these arrays (a sketch of what I mean is below). Or do you think there is a better way?
Best,
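A minimal sketch of the manual-deletion workaround mentioned above (the helper name is hypothetical; it relies on jax.Array.delete() and jax.tree_util.tree_map):

```python
import jax

def release_params(params):
    # Explicitly free the device buffer behind every array in the parameter
    # pytree, so the memory is released even while a stray reference
    # (e.g. from a breakpoint trace) keeps the Python objects alive.
    jax.tree_util.tree_map(lambda x: x.delete(), params)
```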
-
Hi,
I am new to JAX and currently facing severe (GPU) memory issues after porting my code from PyTorch. Hence, I am trying to understand how memory is managed by JAX.
My understanding so far was that, outside of jitted functions, the memory of arrays stays allocated until their respective Python objects are destroyed. However, this is not what I am seeing in my code. As an example, take the following code snippet, in which I create a Vision Transformer instance and initialize it without returning anything (the code does not make sense like this, but it illustrates my problem):
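(A sketch along these lines; this is not the original snippet. The checkpoint name and the init_weights call are assumptions, inferred from the params1, params2, and FlaxViTModel names used below.)

```python
import jax
from transformers import FlaxViTModel

def mk_vit():
    # Load a pretrained ViT; model.params holds its parameter pytree on the device.
    model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    params1 = model.params
    # Re-initialize a second copy of the parameters from scratch.
    params2 = model.init_weights(jax.random.PRNGKey(0), (1, 224, 224, 3))
    # Nothing is returned, so params1, params2, and model go out of scope here.

mk_vit()
```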
Here, I would expect JAX to free up the arrays in `params1` and `params2` immediately after `mk_vit()` returns, since their corresponding Python objects get destroyed at that time. However, I get the following output:

Hence, even after `mk_vit()` returns, the arrays allocated for `params1`, `params2`, and the `FlaxViTModel` are still visible in `xla_bridge.get_backend().live_arrays()` and, judging by the output of `nvidia-smi`, still take up VRAM. Why is that, and how can I prevent it from happening? Do I have to manually delete all unused parameters at the end of each function call? Thanks a lot in advance!
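(As a side note, a small sketch of how the live buffers and their sizes can be inspected while debugging; the helper name is hypothetical, not from this thread.)

```python
from jax.lib import xla_bridge

def report_live_arrays():
    # List every device buffer the backend still considers live, with a rough
    # estimate of the memory it occupies.
    live = xla_bridge.get_backend().live_arrays()
    total_bytes = sum(a.size * a.dtype.itemsize for a in live)
    print(f"{len(live)} live arrays, ~{total_bytes / 1e6:.1f} MB")
    for a in live:
        print("  ", a.shape, a.dtype)
```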
Best,
Tim