Replies: 1 comment
-
This doesn't answer your question, but have you considered using PyKeOps to handle such large matrices in a smart way? https://www.kernel-operations.io/keops/index.html Although it states that bindings are only available for NumPy and PyTorch, you can easily use https://github.com/rdyro/torch2jax as the go-between. This all supports autodiff, vmap, etc. The extra overhead is minimal; in fact, thanks to the highly efficient way KeOps works, it can even be faster despite going through an intermediate step.
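A rough sketch of what that pipeline could look like, assuming the workload can be phrased as a KeOps symbolic reduction (a Gaussian-kernel sum is used here purely as an example) and assuming the torch2jax wrapper API matches its README; all names below (gaussian_kernel_sum, x, y, b) are illustrative and not from the original post:

```python
# Sketch: a KeOps symbolic reduction defined on the PyTorch side, exposed to JAX via torch2jax.
# Assumes pykeops and torch2jax are installed; the workload and all names are illustrative.
import torch
import jax.numpy as jnp
from pykeops.torch import LazyTensor
from torch2jax import torch2jax_with_vjp  # wrapper with gradient support, per the torch2jax README

def gaussian_kernel_sum(x, y, b):
    # x: (M, D), y: (N, D), b: (N, 1) torch tensors
    x_i = LazyTensor(x[:, None, :])      # (M, 1, D) symbolic variable
    y_j = LazyTensor(y[None, :, :])      # (1, N, D) symbolic variable
    D_ij = ((x_i - y_j) ** 2).sum(-1)    # (M, N) squared distances, never materialized in memory
    K_ij = (-D_ij).exp()                 # Gaussian kernel, still symbolic
    return K_ij @ b                      # (M, 1) reduction, computed by the KeOps engine

# Small example tensors used only to trace shapes/dtypes for the wrapper.
M, N, D = 10_000, 4_000, 3
ex_x, ex_y, ex_b = torch.randn(M, D), torch.randn(N, D), torch.randn(N, 1)

# Wrap the torch function so it can be called (and differentiated) from JAX.
jax_kernel_sum = torch2jax_with_vjp(gaussian_kernel_sum, ex_x, ex_y, ex_b, depth=2)

out = jax_kernel_sum(jnp.array(ex_x.numpy()), jnp.array(ex_y.numpy()), jnp.array(ex_b.numpy()))
print(out.shape)  # (M, 1)
```

Whether this actually beats a chunked vmap depends on whether the per-row computation can be expressed as a KeOps reduction; if it can, the (M, N) intermediate is never allocated at all.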
-
Hi,
I'm encountering continuously increasing memory use with JAX that doesn't quite make sense to me (see code below).
I have a very large array, 500,000 x 4,000, that essentially needs to be vmap'ed to produce 500,000 numbers (using func0 given below).
When I tried doing that directly I ran out of memory, so I thought I'd split the vmap into chunks, i.e. apply vmap to array[0:5000, :], then array[5000:10000, :], and so forth, and concatenate the results.
To my surprise, when I did that and watched the memory use, I could still see it increasing by 3 GB per iteration (and in the end I run out of memory). I understand that each vmap will need to keep information for the gradients, but that should be nowhere near 3 GB.
So I'm wondering whether this is a bug or whether I'm missing something here.
Thanks in advance!
The test code illustrating the issue is given below (it requires about 70 GB of RAM to run with the current values of the nspec and npix parameters).
Also, I'm using the CPU backend and JAX 0.4.35.
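The original test code is not preserved in this capture, so here is a minimal sketch of the chunked-vmap pattern described above; nspec, npix and func0 keep the names from the post, but the body of func0 and the chunking helper are placeholders, not the author's code:

```python
# Minimal sketch of the chunked vmap described above (placeholder func0, not the original test code).
import jax
import jax.numpy as jnp

nspec, npix = 500_000, 4_000   # full problem size from the post; shrink these to experiment
chunk = 5_000                   # rows processed per vmap call

def func0(row):
    # Placeholder per-row reduction: (npix,) -> scalar.
    return jnp.sum(row ** 2)

chunked_func = jax.jit(jax.vmap(func0))   # jit once; every equal-sized chunk reuses the compilation

def apply_in_chunks(big):
    outs = []
    for start in range(0, big.shape[0], chunk):
        res = chunked_func(big[start:start + chunk, :])
        # Block so each chunk's device buffers can be freed before the next chunk is dispatched.
        outs.append(res.block_until_ready())
    return jnp.concatenate(outs)

big = jnp.ones((nspec, npix))   # the 500,000 x 4,000 array (already several GB in float32)
result = apply_in_chunks(big)   # (nspec,) numbers
print(result.shape)
```

In this sketch each chunk has the same shape, so chunked_func compiles only once and the loop's steady-state memory should stay roughly at the size of the input array plus one chunk's worth of intermediates.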