Memory-efficient way to obtain a categorical sample #16862
Unanswered
blackblitz asked this question in Q&A
Replies: 1 comment 2 replies
-
JAX's categorical sampler works by taking an argmax over the logits perturbed with Gumbel noise (the Gumbel-max trick): https://github.com/google/jax/blob/a03d6e66137e0bb79350f5af81d39cb4c27e4a70/jax/_src/random.py#L1501-L1504. I think this is where the large memory footprint comes from, because the noise array has the full $(n, m)$ shape. One way you could address this is by doing a …
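One possible way to bound memory while still using JAX's own sampler is to draw the sample in chunks, so only a `(chunk_size, m)` noise array is live at a time instead of `(n, m)`. This is a minimal sketch, not necessarily what the reply above goes on to suggest; the helper name `categorical_chunked` and the `chunk_size` parameter are hypothetical, and it assumes `n` is divisible by `chunk_size`.

```python
import jax
import jax.numpy as jnp


def categorical_chunked(key, logits, n, chunk_size):
    """Hypothetical helper: draw n samples from one categorical distribution.

    logits has shape (m,). Each chunk calls jax.random.categorical with
    shape=(chunk_size,), so its Gumbel noise has shape (chunk_size, m)
    rather than (n, m).
    """
    # Simplifying assumption: n is an exact multiple of chunk_size.
    assert n % chunk_size == 0, "n must be divisible by chunk_size"
    num_chunks = n // chunk_size
    keys = jax.random.split(key, num_chunks)
    # lax.map lowers to a scan, so the chunks are processed one after another
    # instead of all at once.
    samples = jax.lax.map(
        lambda k: jax.random.categorical(k, logits, shape=(chunk_size,)),
        keys,
    )
    # samples has shape (num_chunks, chunk_size); flatten to (n,).
    return samples.reshape(n)
```

Smaller `chunk_size` values lower peak memory at the cost of more sequential iterations.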
-
I would like to obtain a large sample of size $n$ from a categorical distribution with a large number of categories $m$. However, I often run into an out-of-memory error. I have noticed from the traceback that the function builds an array of shape $(n, m)$, so with 4-byte (float32) entries the memory requirement is $4mn$ bytes. For example, if I take a sample of size 10000 from a categorical distribution with 10000 categories, that array alone needs $4 \times 10^4 \times 10^4 = 4 \times 10^8$ bytes (about 400 MB), and it grows with the product $mn$. Is there a way to do this with less memory?
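A different technique that avoids any $(n, m)$ array is inverse-CDF (inverse-transform) sampling: compute the length-$m$ cumulative distribution once, draw $n$ uniforms, and locate each uniform in the CDF with a binary search, for $O(n + m)$ memory. This is a swapped-in method, not what `jax.random.categorical` does internally, and the helper name `categorical_inverse_cdf` is hypothetical. It samples the same distribution up to floating-point rounding in the cumulative sum, which can matter for very many categories in float32.

```python
import jax
import jax.numpy as jnp


def categorical_inverse_cdf(key, logits, n):
    """Hypothetical helper: draw n samples via inverse-CDF sampling.

    logits has shape (m,). Memory is O(n + m): one length-m CDF and
    n uniforms; no (n, m) array is created.
    """
    probs = jax.nn.softmax(logits)
    cdf = jnp.cumsum(probs)
    # Uniforms in [0, 1), one per requested sample.
    u = jax.random.uniform(key, (n,), dtype=cdf.dtype)
    # side="right" returns the first index i with cdf[i] > u, so category k
    # is chosen with probability cdf[k] - cdf[k - 1].
    idx = jnp.searchsorted(cdf, u, side="right")
    # Guard against cdf[-1] rounding to slightly below 1.
    return jnp.minimum(idx, logits.shape[-1] - 1)
```

Because the clamp maps any out-of-range index to the last category, a tiny amount of probability mass (on the order of float32 rounding) can land there even if its probability is zero.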