Hello, I want to append
This is wrapped in a

```python
trajectories = []
# .... code ...
tmp_traj.append(tmp)
is_event = is_event == 1

def handle_event(trajectories, tmp_traj):
    def true_fn(state):
        trajectories, tmp_traj = state
        trajectories.append(tmp_traj)
        tmp_traj = []
        return trajectories, tmp_traj

    def false_fn(state):
        trajectories, tmp_traj = state
        return trajectories, tmp_traj

    return jax.lax.cond(
        is_event,
        true_fn,
        false_fn,
        (trajectories, tmp_traj),
    )

trajectories, tmp_traj = handle_event(
    trajectories, tmp_traj
)
```

When running this, I get the following error:
From what I understand, `jax.lax.cond` requires both branches to return exactly the same output tree structure. In this case:
The nesting and order differ, so JAX refuses to compile it. Since the structure of my return value is dynamic, how can I solve this problem?
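A minimal sketch of the mismatch described above, using a hypothetical toy `state` in place of the elided code. It swaps the in-place `append` for list concatenation so the input is not mutated, prints the tree structures the two branches return, and then passes them to `jax.lax.cond`, which traces both branches and rejects outputs whose structures differ (as far as we know, recent JAX versions raise a `TypeError` here; treat the exact exception type as an assumption):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy state standing in for the elided code: no finished
# trajectories yet, and one in-progress trajectory holding a single step.
state = ([], [jnp.zeros(2)])

def true_fn(state):
    trajectories, tmp_traj = state
    # Non-mutating version of `trajectories.append(tmp_traj); tmp_traj = []`.
    return trajectories + [tmp_traj], []

def false_fn(state):
    # Returns the (trajectories, tmp_traj) tuple unchanged.
    return state

s_true = jax.tree_util.tree_structure(true_fn(state))
s_false = jax.tree_util.tree_structure(false_fn(state))
print(s_true)
print(s_false)
print(s_true == s_false)  # False: the nested list moves between tuple slots

raised = False
try:
    # lax.cond traces both branches and checks that their outputs match.
    jax.lax.cond(True, true_fn, false_fn, state)
except TypeError as err:
    raised = True
    print("lax.cond rejected the branches:", err)
```

Note that the mismatch exists even before any values are involved: the number of leaves in each list depends on `is_event`, which is exactly the kind of value-dependent structure `lax.cond` cannot express.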
Unfortunately, what you're describing is an example of arrays with dynamic shapes (i.e., shapes that depend on values within arrays), and this mode of computation is not supported by JAX transformations like `jit`, `vmap`, `grad`, etc. (see JAX Sharp Bits: Dynamic Shapes for more discussion). Given this, there is no way to do what you are asking; instead, you'll need to figure out how to implement your algorithm in terms of statically shaped arrays. It is not typically too difficult to reformulate computations this way, but the best approach varies from problem to problem. For example, here you might keep a statically-shaped mask over the events array that indicates whether there is a detection…
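Along the lines of the statically shaped reformulation suggested above, here is one possible sketch (not necessarily what the reply had in mind). It assumes hypothetical capacity limits `MAX_TRAJS` and `MAX_LEN` fixed ahead of time, and swaps the Python list appends for a `jax.lax.scan` that writes each step into a preallocated `(MAX_TRAJS, MAX_LEN, D)` buffer. An event advances a trajectory counter instead of changing any list, and a boolean mask records which slots hold real data. The function name `collect_trajectories` and the inputs `xs`/`events` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical capacity limits: they must be Python ints so every shape is
# known at trace time. Steps beyond these limits are silently dropped.
MAX_TRAJS = 4
MAX_LEN = 5

@jax.jit
def collect_trajectories(xs, events):
    """xs: (T, D) per-step data; events: (T,) ints, 1 = trajectory ends here."""
    d = xs.shape[1]
    buf = jnp.zeros((MAX_TRAJS, MAX_LEN, d), xs.dtype)
    lengths = jnp.zeros((MAX_TRAJS,), jnp.int32)

    def step(carry, inp):
        buf, lengths, traj, pos = carry
        x, ev = inp
        # Write this step into slot (traj, pos); JAX drops out-of-bounds
        # scatter updates by default, so overflow is discarded.
        buf = buf.at[traj, pos].set(x)
        lengths = lengths.at[traj].set(jnp.minimum(pos + 1, MAX_LEN))
        is_event = ev == 1
        # On an event, move on to a new trajectory slot (static shapes only).
        traj = jnp.where(is_event, traj + 1, traj)
        pos = jnp.where(is_event, 0, pos + 1)
        return (buf, lengths, traj, pos), None

    init = (buf, lengths, jnp.int32(0), jnp.int32(0))
    (buf, lengths, _, _), _ = jax.lax.scan(step, init, (xs, events))
    # mask[i, j] is True where buf[i, j] holds a real step.
    mask = jnp.arange(MAX_LEN)[None, :] < lengths[:, None]
    return buf, mask, lengths

xs = jnp.arange(12.0).reshape(6, 2)  # 6 steps, 2 features each
events = jnp.array([0, 0, 1, 0, 1, 0])
buf, mask, lengths = collect_trajectories(xs, events)
print(lengths.tolist())  # [3, 2, 1, 0]

# Outside jit, ragged Python lists are fine again:
trajs = [buf[i, :n] for i, n in enumerate(lengths.tolist()) if n > 0]
```

The step carrying an event is included in the trajectory it closes, matching the original loop, and the last slot (here index 2) plays the role of the unfinished `tmp_traj`. The cost of this design is choosing upper bounds in advance; in exchange, every array keeps the same shape whatever the event pattern is.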