Status of the JAX dynamic API #18476
sergei-mironov asked this question in General (Unanswered)
Replies: 1 comment, 4 replies
Hi - I'm a bit confused by what you're asking here. One clarification: it looks like most of your examples fail because they try to create a dynamically-sized array before slicing/indexing it. Is that intentional? For example, this fails:
But that failure has nothing to do with indexing; if you perform the indexing on a statically-sized array, it's fine:
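The code from the original reply did not survive on this page. As a minimal sketch of the distinction being drawn, assuming a default JAX configuration (the experimental dynamic-shapes mode not enabled), the hypothetical functions `make_dynamic` and `index_static` below contrast creating an array whose size is a traced value with indexing an array whose shape is static:

```python
import jax
import jax.numpy as jnp

@jax.jit
def make_dynamic(n):
    # n is traced under jit, so jnp.ones(n) asks for a dynamically-sized array.
    # This is expected to fail before any indexing happens.
    return jnp.ones(n)[0]

@jax.jit
def index_static(arr, idx):
    # arr has a static shape; only the index value idx is traced.
    return arr[idx]

try:
    make_dynamic(4)
    dynamic_creation_failed = False
except Exception:  # the exact exception type may differ between JAX versions
    dynamic_creation_failed = True

arr = jnp.arange(10.0)
value = index_static(arr, 3)  # indexing a statically-sized array with a traced index
```

Here the failure comes from the array creation, not from the `[0]` that follows it, while the same kind of indexing on a statically-sized array traces and runs normally.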
Hello!
Could you please share your status/plans regarding the JAX dynamic API, covered by this test? I am particularly interested in array indexing operations.
To give some context, I am exploring options for dynamic array support in the Catalyst quantum compiler, where we use JAX as a frontend and a modified StableHLO lowering pipeline as a backend. I checked JAX 0.4.19 against some common indexing use-cases; see the table below. Can you confirm that the approach is correct in general? Should we expect improvements in this area in the near future?
arr[42]
arr[idx]
arr[0,:,0]
arr[0:42]
arr[0:idx]
jnp.ones([4,4])[0]
a custom runtime
All the use-cases I tried are of the following form:
Where:
jaxpr_to_mlir is our thin wrapper around lower_jaxpr_to_fun. I think it is still representative.
compile_and_run_mlir runs our customized StableHLO lowering to LLVM. It might contain unrelated issues; I am showing it here for reference.
0.4.19, commit 604dc24 with a few hacks
Updates
arr[:] case from the table due to an error in test case
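The test harness itself is not reproduced on this page, and jaxpr_to_mlir and compile_and_run_mlir are Catalyst-side helpers. As a rough public-API analogue of the trace-then-lower steps described above (an assumption for illustration, not the code used in the post), a use-case such as arr[idx] can be traced and lowered like this; the function name `use_case` is hypothetical:

```python
import jax
import jax.numpy as jnp

def use_case(arr, idx):
    # One of the indexing forms from the table: a traced index into a
    # statically-shaped array.
    return arr[idx]

arr = jnp.arange(10.0)
idx = jnp.int32(3)

# Step 1: trace the function to a jaxpr.
jaxpr = jax.make_jaxpr(use_case)(arr, idx)

# Step 2: lower to MLIR text using the public ahead-of-time lowering API
# (this stands in for the internal lower_jaxpr_to_fun path mentioned above).
mlir_text = jax.jit(use_case).lower(arr, idx).as_text()

# Step 3: run with the stock JAX backend instead of a custom runtime.
result = jax.jit(use_case)(arr, idx)
```

Printing `jaxpr` and `mlir_text` shows which primitives and MLIR operations a given indexing form produces, which can help separate frontend tracing failures from issues in a downstream lowering pipeline.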