JAX hits swap OOM during parallel training on dual A100—even with 400 GB free RAM #29678
Unanswered
tudorcebere
asked this question in
General
Hi JAX community,
First, thank you for the fantastic library!
System specs
What I’m doing
The issue
During jax.jit/XLA compilation, each process briefly allocates into swap. When both models compile at the same time, the 2 GB swap fills and the OS kills the job for OOM, even though ~400 GB of physical RAM remains unused. I verified with htop that RAM stays mostly free while swap grows.
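For anyone who wants to reproduce the observation without htop, here is a minimal sketch that computes swap usage and available RAM from `/proc/meminfo`-style text. It assumes a Linux host; the helper names `parse_meminfo` and `swap_and_ram_snapshot` are made up for illustration and are not part of JAX.

```python
def parse_meminfo(text: str) -> dict[str, int]:
    """Parse /proc/meminfo-style text into {field: value}.

    Most fields are reported in kB; the unit suffix is dropped.
    """
    fields = {}
    for line in text.splitlines():
        key, sep, rest = line.partition(":")
        if not sep:
            continue  # skip lines that are not "Key: value" pairs
        parts = rest.split()
        if parts and parts[0].isdigit():
            fields[key.strip()] = int(parts[0])
    return fields


def swap_and_ram_snapshot(text: str) -> dict[str, int]:
    """Return swap in use and RAM still available, both in kB."""
    info = parse_meminfo(text)
    return {
        "swap_used_kb": info["SwapTotal"] - info["SwapFree"],
        "ram_available_kb": info["MemAvailable"],
    }
```

Calling `swap_and_ram_snapshot(open("/proc/meminfo").read())` right before and right after the compile step should show swap usage rising while available RAM stays high, if the behaviour described above is what is happening.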
Is there a way to tell JAX/XLA never to touch swap (or to keep host allocations in RAM) during compilation? Alternatively, is there a flag or best practice for running parallel JAX compilations that avoids swap exhaustion?
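One possible workaround, sketched here as an assumption rather than a confirmed fix, is to make sure the two processes never compile at the same time by taking an exclusive file lock around the compile step. This does not stop XLA from touching swap; it only keeps the two compile-time peaks from overlapping, and whether a single compile fits in 2 GB of swap is not guaranteed. The names `compile_lock`, `compile_serialized`, and the lock file name are hypothetical, and `fcntl` is Unix-only.

```python
import fcntl
import os
import tempfile
from contextlib import contextmanager

# Hypothetical lock file location shared by all training processes.
DEFAULT_LOCK = os.path.join(tempfile.gettempdir(), "jax_compile.lock")


@contextmanager
def compile_lock(path: str = DEFAULT_LOCK):
    """Hold an exclusive advisory lock for the duration of the block."""
    with open(path, "w") as handle:
        fcntl.flock(handle, fcntl.LOCK_EX)  # blocks until no other process holds it
        try:
            yield
        finally:
            fcntl.flock(handle, fcntl.LOCK_UN)


def compile_serialized(compile_fn, lock_path: str = DEFAULT_LOCK):
    """Run compile_fn() while holding the lock and return its result."""
    with compile_lock(lock_path):
        return compile_fn()


# Usage sketch (train_step and example_args are placeholders, not from the post):
#   compiled = compile_serialized(
#       lambda: jax.jit(train_step).lower(*example_args).compile()
#   )
#   compiled(*example_args)  # later calls reuse the compiled executable
```

The usage comment relies on JAX's ahead-of-time `lower(...).compile()` path so that compilation happens inside the lock instead of lazily on the first call.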
Any pointers or heuristics would be much appreciated!
Thanks,
Tudor