API design for obtaining a categorical sample #16988
blackblitz asked this question in Q&A
Replies: 1 comment
Sounds reasonable. One response:
In JAX, there are two functions that can be used to obtain a categorical sample: `random.choice` and `random.categorical`. Although their APIs are slightly different, they have basically the same functionality: `random.categorical` returns integers that can be used as indices into the items passed to `random.choice`. Another related function is `random.randint`, but it samples from a discrete uniform distribution, which is a special case of a categorical distribution. The following are my observations:

1. Sampling with or without replacement:
   `random.categorical` only allows sampling with replacement, while `random.choice` also allows sampling without replacement.
2. Sampling algorithm:
   `random.categorical` uses the Gumbel-max trick, while `random.choice` uses a range of algorithms depending on the situation:
   - (a) uniform, with replacement: `randint`;
   - (b) uniform, without replacement: some kind of permutation?
   - (c) non-uniform, with replacement: an interval-finding method (although I do not understand why the uniform random variates are subtracted from one);
   - (d) non-uniform, without replacement: the Gumbel-top-k trick.

   The Gumbel-max trick and the Gumbel-top-k trick are closely related, but they are implemented separately in two different functions.
3. Sampling in multiple axes:
   `random.categorical` allows sampling in multiple axes, but `random.choice` allows sampling in one axis only (by default, it samples the "rows" along axis 0).

Would it not be better to unify the APIs and allow the user to specify the sampling algorithm as an optional argument?