Hi all, I'm trying to use JAX to sum over a large array (approximately 1 trillion elements long, which I can't fully materialize), and I'm not entirely sure how to utilize …

My ideal solution would be to iteratively create a subset of this very long array, compute all terms in the subset via …

I did read through https://jax.readthedocs.io/en/latest/sharded-computation.html, but if I'm not mistaken, that requires fully materializing the input data, and as my data comes …

I've attached a minimal reproducible example below to better explain what I'm planning to do. Any help would be greatly appreciated.
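A minimal sketch of the chunk-by-chunk idea described above. It assumes a hypothetical per-element summand `term(i)` and a caller-chosen `chunk_size`, and it is an illustration only, not the asker's attached example. Each iteration materializes one chunk of indices on the host with NumPy, maps `term` over it with `jax.vmap` inside a jitted `chunk_sum`, and adds the partial result to a running total. `jax_enable_x64` is switched on because indices near 10^12 overflow JAX's default 32-bit integers.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Assumption: 64-bit integers are needed for indices beyond ~2**31.
jax.config.update("jax_enable_x64", True)


def term(i):
    # Hypothetical per-element summand; swap in the real computation.
    return 1.0 / (1.0 + i.astype(jnp.float64) ** 2)


@jax.jit
def chunk_sum(idx):
    # vmap applies the scalar summand across one chunk of indices.
    return jnp.sum(jax.vmap(term)(idx))


def host_loop_sum(n_total, chunk_size):
    total = 0.0
    for start in range(0, n_total, chunk_size):
        stop = min(start + chunk_size, n_total)
        # Only this chunk's indices are materialized, on the host.
        idx = np.arange(start, stop, dtype=np.int64)
        total += float(chunk_sum(idx))
    return total
```

For example, `host_loop_sum(1000, 64)` agrees with summing `1 / (1 + i**2)` directly over `range(1000)`. The trade-offs of this simple version: the shorter final chunk triggers one extra compilation, and `float(...)` waits for each chunk to finish before the next one is dispatched.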
So, for anyone who reads this in the future, I managed to find a way to resolve this issue. In my case, I need to materialize a large array with `jnp.arange`. Instead of materializing the entire array, I only materialize the indices that mark the boundaries of the subsets. For M subsets this gives an array of M+1 elements, and the `start` and `stop` indices of subset m are elements m and m+1 of that indices array.

For the `ConcretizationTypeError`, I used `jax.lax.iota`, which seems to handle traced values when creating a `jnp.arange`-like object.
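The boundary-index approach from the reply can be sketched as follows, under these assumptions: `term` is a hypothetical summand, every subset has the same static length `chunk_size` (the last one is masked), and `jax_enable_x64` is enabled so indices near 10^12 fit. Only the M + 1 boundaries are built with `jnp.arange`; inside `jax.lax.fori_loop`, subset m reads its `start` and `stop` from elements m and m + 1. The length passed to `lax.iota` must still be a static Python int, and the traced `start` is only added as an offset. This is what sidesteps the `ConcretizationTypeError` that `jnp.arange(start, stop)` raises when `start` and `stop` are traced.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption: 64-bit indices are needed once N exceeds ~2**31.
jax.config.update("jax_enable_x64", True)


def term(i):
    # Hypothetical per-element summand; replace with the real computation.
    return 1.0 / (1.0 + i.astype(jnp.float64) ** 2)


def chunked_sum(n_total, chunk_size):
    # M subsets need only M + 1 boundary indices, never the full index array.
    n_chunks = -(-n_total // chunk_size)  # ceil division on plain Python ints
    bounds = jnp.minimum(
        jnp.arange(n_chunks + 1, dtype=jnp.int64) * chunk_size, n_total
    )

    @jax.jit
    def run(bounds):
        def body(m, acc):
            start, stop = bounds[m], bounds[m + 1]
            # Static length from chunk_size; the traced start is just an offset.
            idx = lax.iota(jnp.int64, chunk_size) + start
            vals = jax.vmap(term)(idx)
            # Zero out positions past `stop` in the last, possibly partial, chunk.
            return acc + jnp.sum(jnp.where(idx < stop, vals, 0.0))

        return lax.fori_loop(0, n_chunks, body, jnp.float64(0.0))

    return run(bounds)
```

With `n_total = 10**12` and `chunk_size = 10**7`, the only index arrays created are the 100,001 boundaries and one 10^7-element chunk per loop step.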