Can you elaborate here? How does this trip up the compiler? Note that |
We are given `arr = jax.numpy.array(...)`, where `arr.shape` can be anything, and we seek another array `idxs` whose shape is `(*arr.shape, 1)`, having roughly the property that `idxs[i] == i` for all valid `arr`-indices `i`. The precise intended behavior is that of the following code:

Thus, by including
`idxs(arr.shape)` along with `arr` as arguments to a `vmap`, one can know the index at which each call is operating, which is useful in the same way `enumerate` is in standard Python.

The problem with the above `idxs` code is that, although its output shape is statically determined by `arr`, the construction passes through `arr.shape`, which trips up the JIT compiler. There are functions such as `zeros_like`, `full_like`, and so on that take `arr` directly (not `arr.shape`) and work under the JIT compiler. Is there some analogue (an "`idxs_like`") here?
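On the JIT question itself: under `jax.jit`, `arr` becomes a tracer, but `arr.shape` is still an ordinary Python tuple of ints, so a helper that reads it can, at least in a minimal case like the one below, be called inside a jitted function without any traced values involved. The sketch defines a hypothetical `idxs_like(arr)` from `jnp.indices`, plus an equivalent `idxs_like_iota` built from `jax.lax.broadcasted_iota` (one iota per axis), and uses the first inside a jitted function. The function `weighted_sum` is purely illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def idxs_like(arr):
    """Hypothetical 'idxs_like': index array of shape (*arr.shape, arr.ndim).

    Only arr.shape is read; under jax.jit that is a static tuple of ints."""
    return jnp.moveaxis(jnp.indices(arr.shape), 0, -1)


def idxs_like_iota(arr):
    """Same result via lax.broadcasted_iota (assumes arr.ndim >= 1)."""
    # broadcasted_iota(dtype, shape, d) counts 0, 1, ... along axis d.
    return jnp.stack(
        [lax.broadcasted_iota(jnp.int32, arr.shape, d) for d in range(arr.ndim)],
        axis=-1,
    )


@jax.jit
def weighted_sum(arr):
    # Weight each element by the sum of its own coordinates, then reduce.
    ix = idxs_like(arr)
    return jnp.sum(arr * ix.sum(axis=-1))


arr = jnp.ones((2, 3))
total = weighted_sum(arr)  # coordinate sums are [[0, 1, 2], [1, 2, 3]]
```

If a real program does fail under JIT, the cause is more likely a value-dependent shape somewhere else than `arr.shape` itself, though that is a guess without the original code.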