Why is jax failing to allocate memory despite significant memory being free? #30276
-
It would be useful if you could provide a complete minimal example (a self-contained script that another user can run to reproduce the behavior you're seeing). I gave up trying to answer on StackOverflow because there was nothing more I could do with the information you provided. I still suspect that intermediate allocations are to blame. In this line:

`self.ne = n_e0 * (1.0 + s1 * self.XX / self.x_length) * (1 + s2 * jnp.cos(2 * jnp.pi * self.YY / Ly))`

I count 11 intermediate arrays of the same size as the output.
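As an illustration of the intermediate-allocation point, here is a minimal sketch; the function and argument names mirror the expression above, but the shapes are placeholders, not the original program:

```python
import jax
import jax.numpy as jnp

def ne_field(XX, YY, n_e0, s1, s2, x_length, Ly):
    # Evaluated eagerly, each arithmetic op here dispatches separately and
    # materializes its own full-size temporary before the next op runs.
    return n_e0 * (1.0 + s1 * XX / x_length) * (1 + s2 * jnp.cos(2 * jnp.pi * YY / Ly))

# Under jit, XLA can fuse this elementwise chain into a single kernel, so
# peak memory is closer to one output buffer than to many temporaries.
ne_field_jit = jax.jit(ne_field)
```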
-
@nouiz maybe?
-
I have been having issues allocating memory on the GPU with JAX. I am running on a cluster that gives me access to an RTX 6000 (24 GB of VRAM), on which JAX is attempting to allocate memory.
The output of `jax.print_environment_info()` includes an `nvidia-smi` table showing that memory is free on the GPU, and yet JAX fails with a `RESOURCE_EXHAUSTED` error when attempting a 3.64 GB allocation, despite over 16 GB being free on the card. I have attempted to fix this issue with JAX's memory configuration settings.
My initial ideas were that either there was an allocation limit of 4 GB due to 32-bit addressing (hence why I have now ensured 64-bit is enabled), or that it was having difficulty allocating above the default `XLA_PYTHON_CLIENT_MEM_FRACTION=0.75` memory reservation. Hence my increasing the memory fraction to 0.9 and trying to ensure it could allocate past that by setting `TF_GPU_ALLOCATOR=cuda_malloc_async`.
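For reference, a sketch of how those settings are typically applied; the values are the ones described above, and `XLA_PYTHON_CLIENT_PREALLOCATE` is an additional knob worth knowing about, not something guaranteed to fix this error:

```python
import os

# These must be set before jax is imported, or they have no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"   # reserve 90% of VRAM
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # optional: allocate on demand
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"   # async CUDA allocator

import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit types
```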
This is failing when trying to allocate `self.ne_nc`, a scalar field on a meshgrid of size 992³. These meshgrids make up the majority of the successfully allocated 7632 MB (two of them have previously been allocated, each of the same 3.64 GB size, as they are 32-bit grids).
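For reference, the 3.64 GB figure matches a single float32 array on a 992³ grid:

```python
n = 992 ** 3            # 976,191,488 elements
nbytes = n * 4          # float32 -> 3,904,765,952 bytes
print(nbytes / 2**30)   # ~3.64 GiB, matching the failed allocation
```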
This error cropped up as part of a batch script running my program with different parameters to test memory usage at different meshgrid sizes. The 992³ case was the first failure, and each subsequent test failed as the size increased, until eventually the meshgrids themselves stopped allocating as they got larger.
Meshgrid only attempts to allocate two of its three output grids, so in theory there should still be enough memory free to allocate them despite them being larger.
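As an aside, one way to avoid materializing full-size coordinate grids at all is `sparse=True`; a sketch only, untested against the original program, with placeholder grid spacing:

```python
import jax.numpy as jnp

n = 992
x = jnp.linspace(0.0, 1.0, n)
# With sparse=True the outputs have shapes (n, 1, 1), (1, n, 1), (1, 1, n)
# and broadcast to the full (n, n, n) grid only inside later arithmetic.
XX, YY, ZZ = jnp.meshgrid(x, x, x, sparse=True, indexing="ij")
```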
Beyond that point it continues to fail, but this was expected, as it would be attempting to allocate more memory than is available on the card at those domain sizes.
Does anyone have any suggestions as to why it's failing? There is no reason why it should fail to allocate: both the `nvidia-smi` report and my own estimates of memory usage based on the program input show that there is more than enough memory free on the card for these allocations. I suspect that changing the JAX memory configuration flags should fix it, but as of yet I have been unable to find a setup that works. Many thanks.
Complete working example, as requested by @jakevdp:
Link to a Stack Overflow post about this issue: https://stackoverflow.com/questions/79701945/jax-unable-to-allocate-memory-despite-memory-being-free