-
Hi, thanks for the question! I don't know of any good answer to this. I'm not just saying that there's no function in JAX that will give you this; I'm saying I'm not sure it's possible to write an efficient implementation in JAX for what you have in mind. The current [...] To select without replacement from a dynamic number of categories, I suspect you'll have to use some sort of iterative approach, something like:
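The snippet that originally followed here is missing. As a rough sketch of the kind of iterative approach described (not the author's code), here is sequential rejection sampling in plain Python: draw one integer at a time and keep it only if it hasn't been seen yet. The name `generate_iterative` is hypothetical, and the sketch assumes `size <= n`.

```python
import jax
import jax.numpy as jnp


def generate_iterative(key, n, size):
    # Hypothetical helper: sequential rejection sampling OUTSIDE of jit.
    # Keeps drawing integers in [0, n) until `size` distinct values are collected.
    # Assumes size <= n; otherwise the loop never terminates.
    chosen = []
    seen = set()
    while len(chosen) < size:
        key, subkey = jax.random.split(key)
        candidate = int(jax.random.randint(subkey, (), 0, n))
        if candidate not in seen:  # reject duplicates
            seen.add(candidate)
            chosen.append(candidate)
    return jnp.asarray(chosen)
```

Because `n` is only used as a Python value here, it can change freely between calls, but the Python `while` loop and the `int(...)` conversion make this unusable under `jax.jit`.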
That's obviously not JIT compatible, but you could express the same thing using [...]. I'm sure some sort of batched rejection sampling approach could improve this marginally, but the result is going to be very slow, particularly on accelerators, where this kind of sequential processing is poorly supported.
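One way the same rejection loop might be expressed in a JIT-compatible form is with `jax.lax.while_loop`, keeping `size` static and letting `n` be a traced value. This is a sketch under assumptions, not the author's code: the name `choice_without_replacement` is hypothetical, and it assumes `size <= n` (otherwise the loop never terminates).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames="size")
def choice_without_replacement(key, n, size):
    # Hypothetical sketch: sequential rejection sampling inside jit.
    # `size` is static (it fixes the output shape); `n` may be traced.
    def cond(state):
        _, i, _ = state
        return i < size  # stop once `size` unique values are accepted

    def body(state):
        key, i, out = state
        key, subkey = jax.random.split(key)
        candidate = jax.random.randint(subkey, (), 0, n, dtype=jnp.int32)
        # Only the first i slots of `out` hold accepted values.
        filled = jnp.arange(size) < i
        seen = jnp.any(filled & (out == candidate))
        # Accept the candidate only if it is new; otherwise leave state unchanged.
        out = jnp.where(seen, out, out.at[i].set(candidate))
        i = jnp.where(seen, i, i + 1)
        return key, i, out

    init = (key, jnp.int32(0), jnp.zeros(size, dtype=jnp.int32))
    _, _, out = lax.while_loop(cond, body, init)
    return out
```

Each iteration does one draw plus an O(`size`) membership check, and the number of iterations is random, which is exactly the sequential behavior that performs poorly on accelerators.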
-
Maybe you could say more about how you are going to use such a dynamically shaped array? I think that, inside `jit`, even if you could get such an array, some masking trick would still be needed to use it.
-
JIT-friendly:
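The code for this reply did not survive. One JIT-friendly sketch (an assumption, not necessarily what was posted) is to fix a static upper bound `max_n` on `n`, give every candidate in `[0, max_n)` a random score, push candidates `>= n` to the end, and take the first `size` indices of the sort order. The names `generate_masked` and `max_n` are hypothetical; it assumes `size <= n <= max_n`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("max_n", "size"))
def generate_masked(key, n, max_n, size):
    # Hypothetical sketch: `max_n` and `size` are static, `n` may be traced.
    # Random sort keys for every candidate in [0, max_n).
    scores = jax.random.uniform(key, (max_n,))
    # Candidates >= n get +inf so they sort last and are never selected.
    scores = jnp.where(jnp.arange(max_n) < n, scores, jnp.inf)
    # Sorting by random scores gives a random permutation of the valid
    # candidates; the first `size` entries are a sample without replacement.
    return jnp.argsort(scores)[:size]
```

This trades the sequential loop for a single O(`max_n` log `max_n`) sort, which tends to suit accelerators better as long as a reasonable static bound on `n` is known. `jax.random.choice` also accepts a `p` argument, so passing a masked probability vector over `max_n` categories may be another option worth checking against the docs.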
-
I would like to generate a fixed-size array of random uniform integers. The integers must be unique. A standard implementation is this:
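The original snippet is missing. Based on the description, a minimal sketch of such an implementation (names `generate`, `n`, and `size` are assumptions) might look like this, with `n` marked static because `jax.random.choice` needs a concrete integer for its second argument:

```python
from functools import partial

import jax


@partial(jax.jit, static_argnames=("n", "size"))
def generate(key, n, size):
    # Sample `size` unique integers from [0, n). Both `n` and `size` must be
    # static here, so each new value of `n` triggers a recompilation.
    return jax.random.choice(key, n, shape=(size,), replace=False)
```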
This function `generate` will generate an array of the given size, where each element is a random integer and all elements are unique. Now I would like to jit the argument `n`. However, this is not possible because `jax.random.choice` requires `n` to be static. Is there any alternative implementation or tweak?