You can turn off memory preallocation when using functions from other packages by setting the `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable before importing JAX or any package that uses it. Here's an example:

```python
import os

# Set this before jax (or any package that imports jax) is imported.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import package_that_uses_jax  # placeholder for the third-party package
```

Alternatively, you can set this environment variable outside of your Python script, before running it. This affects all JAX operations in your script, regardless of whether they come from JAX directly or from other packages.

```shell
export XLA_PYTHON_CLIENT_PREALLOCATE=false
```

Let me know if this helps!
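A common reason the setting seems to be ignored is ordering: some package imported earlier in the script (or an earlier notebook cell) already pulled in JAX and ran a JAX operation, so the GPU backend was initialized before the variable was set. The sketch below is a hypothetical helper (`disable_jax_preallocation` is not part of JAX) that sets the variable and warns if `jax` is already in `sys.modules`. Importing `jax` alone does not necessarily mean preallocation has happened, so the warning is only a hint to check the import order.

```python
import os
import sys
import warnings


def disable_jax_preallocation() -> bool:
    """Hypothetical helper: ask JAX not to preallocate GPU memory.

    Call it before importing jax or any package that imports jax.
    Returns True if jax had not been imported yet, False otherwise.
    """
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if "jax" in sys.modules:
        # jax is already loaded; if its backend was initialized, the
        # variable set above will not affect the current process.
        warnings.warn(
            "jax was imported before XLA_PYTHON_CLIENT_PREALLOCATE was set; "
            "if a JAX backend is already initialized, the setting will not apply."
        )
        return False
    return True


if __name__ == "__main__":
    applied_early = disable_jax_preallocation()
    # Only after this point import packages that use JAX, e.g.:
    # import package_that_uses_jax  (placeholder name from the example above)
    print("set before jax was imported:", applied_early)
```

In a Jupyter notebook, this means restarting the kernel and running the cell that sets the variable first, because the backend is created once per process.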
I noticed that JAX preallocates 75% of the total GPU memory when the first JAX operation runs. After I set `os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'`, JAX no longer preallocates GPU memory for the JAX operations in my script. However, preallocation still occurs when I call a function from another package that uses JAX internally.
How can I turn off memory preallocation when I use functions from other packages?
Thanks!