It seems OK, because the differences between the unnormalized log probabilities are preserved: normalizing only subtracts the same constant from every entry.
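A short derivation of why that shift does not matter for the Gumbel-max trick. The notation is introduced here for illustration: $\ell_i$ are the unnormalized scores, $Z$ is their normalizer, and $g_i$ are i.i.d. standard Gumbel(0, 1) samples.

```math
\log p_i = \ell_i - \log Z, \qquad Z = \sum_j e^{\ell_j}

\arg\max_i \left(\log p_i + g_i\right)
  = \arg\max_i \left(\ell_i + g_i - \log Z\right)
  = \arg\max_i \left(\ell_i + g_i\right)
```

Since $\log Z$ is the same for every $i$, it cannot change which index wins the argmax. Normalization is a shift, not a rescaling. Multiplying the scores by a constant (a temperature) would change the sampled distribution, but normalizing never does that.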
I am wondering whether the input of `jax.random.categorical` could actually be an unnormalized log probability. I think the input should be a normalized log probability.
`jax.random.categorical` leverages the Gumbel-max trick, which is used in various papers. In this paper, they explain the trick using normalized log probabilities, and I think that makes sense.
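A minimal pure-Python sketch of the Gumbel-max trick, useful for checking the question empirically. This is not JAX's implementation. The helper names `gumbel_max_sample`, `softmax`, and `empirical_frequencies` are hypothetical, and the noise is assumed to be standard Gumbel(0, 1), generated as `-log(-log(U))` with `U` uniform on (0, 1). The sketch samples from arbitrary unnormalized scores and from their normalized (log-softmax) version, then compares the empirical frequencies with the softmax probabilities.

```python
import math
import random


def gumbel_max_sample(logits, rng):
    """Return index i with probability softmax(logits)[i] via the Gumbel-max trick."""
    best_index, best_value = 0, -math.inf
    for i, logit in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # random() may return 0.0, and log(0) is undefined
            u = rng.random()
        gumbel_noise = -math.log(-math.log(u))  # one standard Gumbel(0, 1) draw
        if logit + gumbel_noise > best_value:
            best_index, best_value = i, logit + gumbel_noise
    return best_index


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def empirical_frequencies(logits, n_samples=100_000, seed=0):
    """Fraction of Gumbel-max samples that landed on each index."""
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(n_samples):
        counts[gumbel_max_sample(logits, rng)] += 1
    return [c / n_samples for c in counts]


unnormalized = [1.0, 2.0, 3.0]  # arbitrary scores, not log probabilities
log_z = math.log(sum(math.exp(x) for x in unnormalized))
normalized = [x - log_z for x in unnormalized]  # log-softmax: same scores shifted by -log Z

print(softmax(unnormalized))
print(empirical_frequencies(unnormalized))
print(empirical_frequencies(normalized))
```

Both frequency lists should come out close to `softmax(unnormalized)`, roughly `[0.090, 0.245, 0.665]`. Shifting every score by `-log Z` leaves the sampled distribution unchanged.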
However, when we use unnormalized log probabilities, the scale of the values sampled from the Gumbel distribution stays the same while the scale of the unnormalized log probabilities changes. This means the sampled result would depend on the scale of the unnormalized log probabilities.
Therefore, my question is: could we use unnormalized log probabilities for this function? My current thought is that we should use normalized log probabilities.
To make this clearer, I have described how I think about the difference between unnormalized and normalized log probabilities.
Also, I have attached the current implementation of `jax.random.categorical`.