
JAX: a for loop that handles values with different shapes #22001

Closed. Answered by eadadi
eadadi asked this question in Q&A

The way I solved this is by using:

jax.tree_util.tree_map(op, traj, keys)

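where op is the function I apply to each value in traj, and keys holds the keys of traj, which I also needed inside that function. For tree_map to accept it, keys must have the same tree structure as traj (for a dict, e.g. a dict mapping each key to itself), so that op is called as op(value, key) for every leaf.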
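A minimal sketch of this pattern, assuming traj is a flat dict of arrays with different shapes. The trajectory contents and the body of op are hypothetical, not taken from the original post. The second half shows jax.tree_util.tree_map_with_path, a swapped-in alternative (not what the post used) that hands op the key path directly, so no separate keys tree is needed.

```python
import jax
import jax.numpy as jnp

# Hypothetical trajectory: each entry has a different shape, so the entries
# cannot be stacked into one array and processed by a single array loop.
traj = {
    "obs": jnp.ones((5, 3)),
    "action": jnp.ones((5,)),
    "reward": jnp.ones((5, 1)),
}

# Same tree structure as traj; each leaf is the key itself (a plain string).
keys = {k: k for k in traj}

def op(value, key):
    # Hypothetical per-leaf logic: scale rewards, leave everything else as-is.
    # `key` is a Python string, so this branch is resolved at trace time.
    if key == "reward":
        return value * 0.5
    return value

# tree_map walks traj and keys in lockstep and calls op(value, key) per leaf.
out = jax.tree_util.tree_map(op, traj, keys)

# Alternative: tree_map_with_path passes the path to each leaf as the first
# argument; for a flat dict the path is a 1-tuple holding a DictKey.
def op_with_path(path, value):
    return op(value, path[0].key)

out2 = jax.tree_util.tree_map_with_path(op_with_path, traj)

# Both forms also work under jit, because the keys are static Python values.
out_jit = jax.jit(lambda t: jax.tree_util.tree_map(op, t, keys))(traj)
```

Because tree_map runs op once per leaf in Python while tracing, each differently shaped value gets its own traced computation; no padding or reshaping to a common shape is required.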

Replies: 1 comment

Answer selected by eadadi
Category
Q&A
1 participant