I'm building a system in JAX that needs to run a series of procedures sequentially. For example:

```python
from lib import f1, f2, f3

for i in range(100):
    f1(i)
for i in range(100):
    f2(i)
for i in range(100):
    f3(i)
```

The problem is that each of `f1`, `f2`, `f3` is very memory-intensive, and GPU memory can only accommodate one at a time. Because of the way JAX's compilation cache works, `f2` fails on its first call: the GPU runs out of memory while compiling its subroutines, since all the memory is still used up by `f1`. It would be nice to be able to force JAX to clear its compilation cache between the loops, because `f1`'s compiled code doesn't need to be kept around to run `f2`. Is there a way to do this?
Does this work for you?

```python
from jax import jit

@jit
def f(x):
    return x + 1

print(f._cache_size())
# 0
f(1.0)
print(f._cache_size())
# 1
f(1)  # integer input: a different dtype, so a second compiled entry
print(f._cache_size())
# 2
f._clear_cache()
print(f._cache_size())
# 0
```

If not, there may be references hanging around that will need to be garbage collected somehow.
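Building on the per-function calls shown above, here is a minimal sketch that clears the cache of a finished stage and then runs Python's garbage collector, in case lingering references keep the compiled executables alive. It assumes the private `_cache_size()` / `_clear_cache()` methods from the snippet above, which are not public API and may change between JAX versions; `stage_a`, `stage_b`, and the helper `clear_stage_caches` are hypothetical names for illustration.

```python
import gc

import jax
import jax.numpy as jnp


@jax.jit
def stage_a(x):
    # Hypothetical stand-in for a memory-intensive stage such as f1.
    return jnp.sin(x) * 2.0


@jax.jit
def stage_b(x):
    # Hypothetical stand-in for the next stage, such as f2.
    return jnp.cos(x) + 1.0


def clear_stage_caches(*fns):
    """Hypothetical helper: drop each jitted function's cached executables."""
    for fn in fns:
        fn._clear_cache()  # private JAX method; may change between versions
    # Reclaim Python objects that may still hold references to the executables.
    gc.collect()


for i in range(3):
    stage_a(float(i))
# All calls share one input signature (a weakly typed float scalar).
size_after_stage_a = stage_a._cache_size()

clear_stage_caches(stage_a)
size_after_clear = stage_a._cache_size()

stage_b_out = stage_b(0.0)  # the next stage compiles on its first call
```

Clearing only the functions you are done with leaves the caches of other jitted functions untouched, at the cost of relying on private methods.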
We finally added `jax.clear_caches()`! See #15448.
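Applied to the loops from the original question, a minimal sketch might call `jax.clear_caches()` after each stage finishes. The bodies of `f1`, `f2`, `f3` and the helper `run_stages` are hypothetical stand-ins (the question imports the real functions from `lib`), and whether this frees enough GPU memory in practice also depends on nothing else holding references to the old stage's buffers.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the question's memory-intensive f1, f2, f3.
@jax.jit
def f1(i):
    return jnp.ones((256, 256)) * i


@jax.jit
def f2(i):
    return jnp.ones((256, 256)) + i


@jax.jit
def f3(i):
    return jnp.ones((256, 256)) - i


def run_stages(stages, n_iters=100):
    """Hypothetical helper: run each stage n_iters times (n_iters >= 1),
    clearing JAX's caches between stages."""
    results = []
    for stage in stages:
        for i in range(n_iters):
            out = stage(i)
        out.block_until_ready()  # make sure the stage's work has finished
        results.append(out)
        jax.clear_caches()  # drop cached compiled executables for all jitted functions
    return results


results = run_stages([f1, f2, f3], n_iters=10)
```

Because `jax.clear_caches()` clears caches globally, `f1` would be recompiled if it were called again later; clearing a single function with the private `_clear_cache()` is the narrower alternative.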