
jax.lax.scan using lists of arrays #13898

Answered by jakevdp
pvasired asked this question in Q&A

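`scan` iterates over the leading axis of arrays, not over the elements of a Python list. (A list passed as `xs` is treated as a pytree container, so `scan` steps through each list element along that element's own leading axis in lockstep.) So one way to accomplish what you want is to stack your inputs into arrays, so that you're scanning over the first axis of each: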

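import jax.numpy as jnp

# Stack each list into a single array; scan will iterate over axis 0 (the rows).
arrays1 = jnp.stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])])  # shape (2, 3)
arrays2 = jnp.stack([jnp.array([7, 8]), jnp.array([9, 10])])       # shape (2, 2)

# example_fun is the function from the original question (not reproduced here).
example_fun((arrays1, arrays2))
# 55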
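Since the question's `example_fun` isn't reproduced in this thread, here is a minimal sketch of a function that fits the call above. It is a hypothetical stand-in, assuming the function sums every element; it is not the asker's original code. Each `scan` step receives one row of `arrays1` and one row of `arrays2`.

```python
import jax
import jax.numpy as jnp

def example_fun(xs):
    """Hypothetical stand-in for the question's function: sums every element
    by scanning over the leading axis of each stacked array."""
    arrays1, arrays2 = xs

    def step(carry, x):
        a, b = x  # one row of arrays1 and the matching row of arrays2
        return carry + a.sum() + b.sum(), None

    # Carry dtype matches the inputs so scan's carry type stays consistent.
    init = jnp.zeros((), dtype=arrays1.dtype)
    total, _ = jax.lax.scan(step, init, (arrays1, arrays2))
    return total

arrays1 = jnp.stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])])
arrays2 = jnp.stack([jnp.array([7, 8]), jnp.array([9, 10])])
print(example_fun((arrays1, arrays2)))  # 55
```

Note that the rows of `arrays1` and `arrays2` can have different widths (3 vs. 2); only the leading axis, the one being scanned, must have the same length in both.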

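To see why passing the list directly doesn't do what you might expect, here is a small sketch (illustrative inputs, not from the original question). Because the list is a pytree, each step receives a list containing one scalar from *each* array, not one whole array per step:

```python
import jax
import jax.numpy as jnp

xs = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])]

def step(carry, x):
    # x is a list holding element i of each array, e.g. [1, 4] on the first step
    return carry, x[0] + x[1]

_, ys = jax.lax.scan(step, None, xs)
print(ys)  # [5 7 9]
```

The scan runs for 3 steps (the arrays' length), not 2 (the list's length). This also means every array in the list must share the same leading-axis length, or `scan` will raise an error.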
Answer selected by pvasired
Category
Q&A
4 participants