
Different autograd behaviors between lax.scan and for-loop #16819

Closed Answered by anh-tong
moskomule asked this question in Q&A

Hi,

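I think the main problem here is that the code scans over a list. samples is a list of arrays, and jax.lax.scan does not iterate over the elements of a list (see #13898 for more). Instead, it treats the list as a pytree and scans along the leading axis of every array in the list at the same time, so each step receives a list holding one slice from each array rather than one element of the list. The code obtains the expected result by changing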

(inner, state), _ = jax.lax.scan(f, (inner, state), samples[:-1])

to

(inner, state), _ = jax.lax.scan(f, (inner, state), jnp.stack(samples)[:-1])
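To make the pytree behavior concrete, here is a minimal sketch with made-up data (the question's actual f and samples are not shown here, so the values and the count_steps helper below are hypothetical). The carry simply counts how many steps scan runs, and each step returns its input x unchanged so the shape of what scan passed in can be inspected.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: three samples, each a vector of length 2.
samples = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), jnp.array([5.0, 6.0])]

def count_steps(carry, x):
    # Count iterations in the carry and echo x back so we can see what scan fed in.
    return carry + 1, x

# List input: scan slices every array along its leading axis in parallel,
# so it runs 2 steps (the vector length) and each x is a list of 3 scalars.
n_list, ys_list = jax.lax.scan(count_steps, jnp.array(0), samples)

# Stacked input of shape (3, 2): scan runs 3 steps, one full sample per step.
n_stacked, ys_stacked = jax.lax.scan(count_steps, jnp.array(0), jnp.stack(samples))
```

Here n_list is 2 and ys_list is a list of three length-2 arrays, while n_stacked is 3 and ys_stacked has shape (3, 2). With the list input, scan runs a different number of steps and f sees a different kind of input than in the for-loop, which is why the results (and gradients) differ.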

Here is the Colab notebook for reference.
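As a sanity check that stacking makes scan agree with a plain Python for-loop, including under autograd, here is a small sketch. The update rule, the scalar weight w, and the helper names run_for_loop and run_scan are assumptions for illustration, not the code from the question.

```python
import jax
import jax.numpy as jnp

# Hypothetical update rule standing in for the question's f: a scalar state
# driven by a scalar weight w and one sample vector x per step.
def update(w, state, x):
    return jnp.tanh(w * state + jnp.sum(x))

def run_for_loop(w, state0, samples):
    state = state0
    for x in samples[:-1]:  # every sample except the last
        state = update(w, state, x)
    return state

def run_scan(w, state0, samples):
    def step(state, x):
        return update(w, state, x), None
    # Stack the list into one (n, d) array so scan takes one sample per step.
    state, _ = jax.lax.scan(step, state0, jnp.stack(samples)[:-1])
    return state

key = jax.random.PRNGKey(0)
samples = list(jax.random.normal(key, (4, 3)))  # hypothetical: 4 samples of length 3
state0 = jnp.array(0.1)
w = 0.7

# Final states and gradients with respect to w should agree up to float error.
out_for = run_for_loop(w, state0, samples)
out_scan = run_scan(w, state0, samples)
g_for = jax.grad(run_for_loop)(w, state0, samples)
g_scan = jax.grad(run_scan)(w, state0, samples)
```

Both versions apply the same updates in the same order, so jax.grad differentiates the same computation whether it is unrolled in Python or expressed with scan.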

Answer selected by moskomule