How to identify amount of GPU memory used? #26252
Unanswered · Jacob-Stevens-Haas asked this question in Q&A

I've got a few GPUs on some machines that our lab shares. I'm trying to figure out how to improve the experience of running JAX processes, mostly for people who run offline experiments, but also for those who play in Jupyter notebooks. I've read the guide on memory allocation, but I still want to know:

a) How does JAX choose a device for an operation if `jax_default_device` hasn't been set?

a.5) When does it decide to use a different device from the default?

b) JAX preallocates 75% of a GPU, and I want to find out how much I'm actually using. Short of using a profiler, which seems more involved than necessary, is there a way to see the peak GPU memory that JAX actually uses (e.g., in DEBUG logs)? (See the allocation sketch below.)

c) If not, is there a global `block_until_ready()`? E.g., if I have a function `do_jax_work(some, jax, arrays)`, how can I verify that all work is, in fact, done? (See the pytree sketch below.)

d) Is there a way to queue JAX processes to wait until a GPU is available, rather than have them error out when they can't meet their preallocation? (See the polling sketch below.)

Thanks for any help and guidance you can provide!
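On (b): a minimal sketch of the relevant knobs, assuming the documented `XLA_PYTHON_CLIENT_*` environment variables from the memory-allocation guide and the per-device `memory_stats()` method; `memory_stats()` is backend-dependent and can return `None`.

```python
import os

# Must be set before JAX initializes its backend: allocate on demand instead
# of reserving the default 75% fraction up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, shrink the preallocated fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096))
jax.block_until_ready(x @ x)

stats = jax.devices()[0].memory_stats()  # may be None on some backends
if stats is not None:
    gib = 2**30
    print(f"peak: {stats.get('peak_bytes_in_use', 0) / gib:.2f} GiB "
          f"(limit {stats.get('bytes_limit', 0) / gib:.2f} GiB)")
```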
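On (c): `jax.block_until_ready` accepts an arbitrary pytree, so blocking on everything a function returns is usually barrier enough. A minimal sketch with a stand-in for the hypothetical `do_jax_work`:

```python
import jax
import jax.numpy as jnp

def do_jax_work(a, b):
    # Stand-in for the question's do_jax_work(some, jax, arrays).
    return {"sum": a + b, "prod": a @ b}

out = do_jax_work(jnp.ones((8, 8)), jnp.eye(8))
# Walks the pytree and blocks on every array leaf, so all asynchronously
# dispatched work behind `out` has finished once this returns.
out = jax.block_until_ready(out)
```

Note this only covers work reachable from the arrays you pass in; computations whose results were dropped aren't waited on.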
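On (d): JAX has no built-in job queueing as far as I know, so one workaround is to poll free GPU memory before the process imports JAX. A hedged sketch using `nvidia-smi` (the 20 GiB threshold is a made-up example, and `wait_for_free_gpu` is a hypothetical helper):

```python
import os
import subprocess
import time

def wait_for_free_gpu(min_free_mib=20_000, poll_seconds=30):
    """Block until some GPU has at least min_free_mib MiB free; return its index."""
    while True:
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.free",
             "--format=csv,noheader,nounits"],
            text=True,
        )
        free_mib = [int(line) for line in out.strip().splitlines()]
        for idx, mib in enumerate(free_mib):
            if mib >= min_free_mib:
                return idx
        time.sleep(poll_seconds)

# Pin JAX to the chosen GPU; must happen before `import jax`.
os.environ["CUDA_VISIBLE_DEVICES"] = str(wait_for_free_gpu())
import jax  # noqa: E402
```

This is racy when several queued jobs poll at once; a lock file or a real scheduler (e.g. Slurm with GPU gres) is more robust.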
Replies: 2 comments
- For memory use, I use the following function (found in https://github.com/cabouman/mbirjax/blob/main/mbirjax/memory_stats.py):

  ```python
  import jax

  def get_memory_stats(print_results=True, file=None):
      ...
  ```
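  The snippet above is cut off; a minimal sketch with the same signature, assuming only the public per-device `memory_stats()` API (the linked mbirjax version does more formatting), might look like:

  ```python
  import jax

  def get_memory_stats(print_results=True, file=None):
      """Condensed sketch, not the mbirjax original: per-device memory stats."""
      gb = 1024 ** 3
      results = []
      for device in jax.local_devices():
          stats = device.memory_stats() or {}  # may be None on some backends
          entry = {
              "device": str(device),
              "bytes_in_use_GB": stats.get("bytes_in_use", 0) / gb,
              "peak_bytes_in_use_GB": stats.get("peak_bytes_in_use", 0) / gb,
              "bytes_limit_GB": stats.get("bytes_limit", 0) / gb,
          }
          results.append(entry)
          if print_results:
              print(entry, file=file)
      return results
  ```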
- You can use …