Using a traced array in lax.cond()
#17620
Hi, I have this function
which I call in another function
is it possible to use
? Thank you
Thanks for the question!
If I understand the question correctly, yes, that should work. (But based only on what you said, a simple
t = jnp.float32(1) if x < 0.5 else jnp.sqrt(jnp.float32(2)/(alpha+lam))
would work too, unless everything is under an outer jit. In that case x is a tracer, so x < 0.5 has no concrete Python bool value for the if to branch on, and you'd need a lax.cond again. That's basically what lax.cond is there for: it always works regardless of what tracing/transformations are going on.)
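A minimal sketch of the difference described above, assuming x, alpha and lam are scalars. The function names t_python_if and t_lax_cond and the helper python_if_fails_under_jit are hypothetical, chosen for illustration; only the expression for t comes from the reply. The Python if version runs eagerly but fails once wrapped in jax.jit, while the lax.cond version works in both cases.

```python
import jax
import jax.numpy as jnp
from jax import lax


def t_python_if(x, alpha, lam):
    # Plain Python conditional: needs a concrete bool for `x < 0.5`,
    # so it only works when x is not a tracer (e.g. eager execution).
    return jnp.float32(1) if x < 0.5 else jnp.sqrt(jnp.float32(2) / (alpha + lam))


def t_lax_cond(x, alpha, lam):
    # lax.cond accepts a traced predicate: both branches are traced and
    # the branch is selected at run time, so this also works under jit.
    return lax.cond(
        x < 0.5,
        lambda a, l: jnp.float32(1),
        lambda a, l: jnp.sqrt(jnp.float32(2) / (a + l)),
        alpha,
        lam,
    )


def python_if_fails_under_jit(x, alpha, lam):
    # Under jit, `x < 0.5` is a traced boolean; converting it to a Python
    # bool raises a ConcretizationTypeError (or a subclass of it).
    try:
        jax.jit(t_python_if)(x, alpha, lam)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


if __name__ == "__main__":
    print(t_python_if(0.9, 1.0, 3.0))           # eager: fine
    print(jax.jit(t_lax_cond)(0.9, 1.0, 3.0))   # jitted: fine
    print(python_if_fails_under_jit(0.9, 1.0, 3.0))
```

Passing alpha and lam as lax.cond operands (rather than closing over them) is a stylistic choice here; closing over them inside the lambdas would behave the same way.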