
How to vectorize and parallelize over a very long (1e12 samples) array with shard_map #24338

Answered by jwtkeeble
jwtkeeble asked this question in Q&A

So, for anyone who reads this in the future, I managed to find a way to resolve this issue. In my case I needed to materialize a large array with jnp.arange. Instead of materializing the entire array, I only materialize the boundary indices of its subsets: for M subsets this is an array of M+1 elements, and subset m starts at element m and stops at element m+1 of that boundaries array.

For the ConcretizationTypeError, I used jax.lax.iota, which can build a jnp.arange-like array inside traced code. Its length must still be a static Python int, but a traced start index can be added to it, whereas jnp.arange needs concrete start and stop values.
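A minimal sketch of the approach described above, assuming the goal is a reduction (here a sum) of a per-sample function over the index range [0, N). The per-sample function `f` and the helpers `chunk_boundaries`, `chunk_sum` and `total_sum` are hypothetical placeholders, not code from the original thread. Because `lax.iota` needs a static length, every subset is given the same padded length and indices past the subset's stop boundary are masked out. To stay runnable on a single device, the sketch loops over subsets with `lax.map` rather than `shard_map`; the per-subset function is the piece you would distribute across devices.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Indices beyond 2**31 - 1 need 64-bit integers, which JAX disables by default.
jax.config.update("jax_enable_x64", True)


def chunk_boundaries(n_total: int, n_chunks: int):
    """Return the M+1 boundaries [b_0, ..., b_M]; subset m is [b_m, b_{m+1})."""
    chunk_size = -(-n_total // n_chunks)  # ceil division, stays a static Python int
    starts = jnp.arange(n_chunks + 1, dtype=jnp.int64) * chunk_size
    return jnp.minimum(starts, n_total), chunk_size


def f(x):
    # Hypothetical per-sample function; replace with the real computation.
    return jnp.sin(x.astype(jnp.float64))


def chunk_sum(start, stop, chunk_size: int):
    # lax.iota needs a static length, but adding a traced `start` is fine,
    # unlike jnp.arange(start, stop), whose bounds must be concrete.
    idx = start + lax.iota(jnp.int64, chunk_size)
    valid = idx < stop  # mask the padding in the last (shorter) subset
    return jnp.sum(jnp.where(valid, f(idx), 0.0))


@partial(jax.jit, static_argnums=(0, 1))
def total_sum(n_total: int, n_chunks: int):
    bounds, chunk_size = chunk_boundaries(n_total, n_chunks)
    starts, stops = bounds[:-1], bounds[1:]
    # lax.map processes the subsets one after another, so only a
    # chunk_size-long index array is built per step, never all n_total indices.
    per_chunk = lax.map(lambda ss: chunk_sum(ss[0], ss[1], chunk_size), (starts, stops))
    return jnp.sum(per_chunk)
```

For example, `total_sum(10_000, 7)` splits the range into 7 subsets of at most 1429 indices each. For N = 1e12 the same call with a larger `n_chunks` keeps memory bounded by the subset size, although the compute still scales with N. Distributing the `(starts, stops)` arrays over devices with `shard_map` is not shown here, since its import path has changed between JAX versions.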
