Memory issues with grad and lax.scan
#25181
Unanswered
vboussange asked this question in Q&A
Hey there,
I am looking to implement a differentiable least-cost path algorithm that works on large graphs. I have been considering the Bellman-Ford algorithm, as I am especially interested in one-to-all vertex distances. Here is a custom implementation that seems to work beautifully:
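The implementation itself did not survive on this page. As a stand-in for readers, here is a minimal, hypothetical sketch (not the author's code) of Bellman-Ford written with lax.scan. It assumes the graph is stored as an edge list (src, dst, weights), has no negative cycles, and that num_nodes is a static Python int; the function name bellman_ford and its signature are illustrative only.

```python
import jax.numpy as jnp
from jax import lax


def bellman_ford(weights, src, dst, source, num_nodes):
    """One-to-all shortest-path distances from `source` (hypothetical sketch).

    The graph is an edge list: edge e goes src[e] -> dst[e] with cost weights[e].
    Assumes no negative cycles; num_nodes must be a static Python int.
    """
    # Every node starts unreachable (inf) except the source.
    dist0 = jnp.full((num_nodes,), jnp.inf, dtype=weights.dtype).at[source].set(0.0)

    def relax(dist, _):
        # Relax all edges at once: candidate distance to dst[e] through edge e ...
        candidates = dist[src] + weights
        # ... then keep the smallest value per node (scatter-min handles
        # several edges pointing at the same node).
        return dist.at[dst].min(candidates), None

    # num_nodes - 1 rounds of relaxation suffice without negative cycles.
    dist, _ = lax.scan(relax, dist0, xs=None, length=num_nodes - 1)
    return dist


# Toy graph: edges 0->1 (1.0), 0->2 (4.0), 1->2 (2.0), 2->3 (1.0).
src = jnp.array([0, 0, 1, 2])
dst = jnp.array([1, 2, 2, 3])
weights = jnp.array([1.0, 4.0, 2.0, 1.0])
print(bellman_ford(weights, src, dst, source=0, num_nodes=4))  # [0. 1. 3. 4.]
```

Storing the graph as an edge list keeps each relaxation step O(E) in time and memory, rather than the O(V²) a dense weight matrix would need on a large graph.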
However, when trying to differentiate bellman_ford with respect to the edge weights, I encounter a RESOURCE_EXHAUSTED: Out of memory error. The @jax.checkpoint decorator seems to reduce memory usage, but it is not sufficient to prevent memory from building up. I actually suspect a memory leak. Am I doing something wrong here? Is there a better implementation that avoids exhausting memory?
Side note: I’m considering creating a JAX-based package for graph utilities (all JIT-compatible and differentiable). If anyone is interested, let me know!
Thanks in advance for your help!