Is it possible to use jax.lax.scan to rewrite a for loop across several lists of jax.numpy arrays, where all lists have the same length but the elements of the lists may have different shapes? For example:
What I would want this toy example to do is return 0 + (sum(1, 2, 3) + sum(7, 8)) + (sum(4, 5, 6) + sum(9, 10)) = 21 + 34 = 55. This seems like it is possible since lists are pytrees (correct?), but I keep getting error messages about mismatched first axes dimensions when I try to do this. I don't understand where this error is coming from because both lists arrays1 and arrays2 have the same length (2). Specifically, the error is:
I need my inputs to be lists, or something else that can handle jagged shapes, because in my real work the elements of the lists do not have a consistent shape. I cannot zero-pad the arrays to give them a consistent shape because that breaks the code (it makes some matrices singular). I appreciate any help I can get here! This is the first time I am using jax.lax.scan, so I may just be doing something totally wrong.
scan is able to scan over the leading axis of arrays, not over the elements of a Python list. A list passed to scan is treated as a pytree, and every array in it must have the same leading-axis length, which is why the shapes (3,) and (2,) in your two lists trigger the mismatched-axis error. So one way you could accomplish what you want is to stack your inputs into arrays so that you're scanning over the first axis of each:
arrays1 = jnp.stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])])
arrays2 = jnp.stack([jnp.array([7, 8]), jnp.array([9, 10])])
example_fun((arrays1, arrays2))
# 55
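The definition of example_fun is not visible in the thread as captured here, so the following is only a minimal sketch of what it might look like, assuming it simply adds the sums of each pair of rows to a running total. The names `step`, `init`, `list1`, `list2`, and `leaf_shapes` are made up for illustration. The sketch also prints the leaf shapes of the original lists to show which leading axes scan would try to iterate over.

```python
import jax
import jax.numpy as jnp


def example_fun(arrays):
    """Hypothetical stand-in for the question's function (an assumption)."""

    def step(carry, xs):
        # xs holds one slice along the leading axis of each stacked array.
        a1, a2 = xs
        return carry + jnp.sum(a1) + jnp.sum(a2), None

    # Match the carry dtype to the inputs so the carry type is stable.
    init = jnp.zeros((), dtype=arrays[0].dtype)
    total, _ = jax.lax.scan(step, init, arrays)
    return total


# The lists from the toy example in the question.
list1 = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])]
list2 = [jnp.array([7, 8]), jnp.array([9, 10])]

# As a pytree, the lists flatten to four separate array leaves.
# scan would iterate over the leading axis of each leaf: 3, 3, 2, 2.
leaf_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves((list1, list2))]
print(leaf_shapes)  # [(3,), (3,), (2,), (2,)]

# Stacking gives two arrays that share a leading axis of length 2.
arrays1 = jnp.stack(list1)  # shape (2, 3)
arrays2 = jnp.stack(list2)  # shape (2, 2)
print(example_fun((arrays1, arrays2)))  # 55
```

Note that jnp.stack only works when the arrays within each list share a shape, as they do in the toy example.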