
lax.cond: true_fn and false_fn output must have same type structure #30171

Answered by jakevdp
Theo-Fan asked this question in Q&A

Unfortunately, what you're describing is an example of arrays with dynamic shapes (i.e., shapes that depend on the values within arrays), and this mode of computation is not supported by JAX transformations like jit, vmap, grad, etc. (See JAX Sharp Bits: Dynamic Shapes for more discussion.) Given this, there is no way to do what you are asking. Instead, you'll need to implement your algorithm in terms of statically shaped arrays. Reformulating computations this way is typically not too difficult, but the best approach varies from problem to problem. For example, here you might keep a statically-shaped mask over the events array that indicates whether there is a detection…
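A minimal sketch of the mask-based approach, assuming the goal was to keep only events above some threshold. The names `events`, `threshold`, `use_detections`, and `process` are hypothetical and do not come from the original question. Both branches of `lax.cond` return arrays with the same shape and dtype, and detections are tracked with a boolean mask rather than filtered out:

```python
import jax
import jax.numpy as jnp

# Hypothetical example: `events` holds a fixed number of measurements and a
# "detection" is any event above `threshold`. These names are illustrative only.

@jax.jit
def process(events, threshold, use_detections):
    # A data-dependent filter such as events[events > threshold] has a length
    # that is only known at runtime, so it cannot be traced under jit.
    # Keep the full, static shape and carry a boolean mask instead.
    mask = events > threshold

    # Both branches return arrays of shape events.shape and dtype events.dtype,
    # which satisfies lax.cond's requirement of identical output structure.
    masked_events = jax.lax.cond(
        use_detections,
        lambda e: jnp.where(mask, e, 0.0),  # zero out non-detections
        lambda e: e,                        # keep every event unchanged
        events,
    )
    n_detections = jnp.sum(mask)  # a scalar, so its shape is static too
    return masked_events, mask, n_detections


events = jnp.array([0.5, 2.0, 0.25, 3.5])
masked, mask, n = process(events, 1.0, True)
```

Downstream code can then reduce with the mask (for example `jnp.sum(jnp.where(mask, x, 0.0))`) instead of operating on a variable-length array, so every intermediate keeps a shape known at trace time.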

Replies: 1 comment

Answer selected by Theo-Fan
Category
Q&A
2 participants