I've been using `vmap` to batch some calculations. The input data is fairly large, and the calculations produce some large intermediate arrays. To avoid out-of-memory errors, I've split the input into smaller chunks and used `jax.lax.map` to loop over the chunks, with `vmap` inside. I noticed that PyTorch's `vmap` has a `chunk_size` parameter that handles this without needing an outer loop or map. Is there anything like this available (or planned) in JAX?
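A minimal sketch of the chunked `lax.map`-over-`vmap` workaround described in the question, assuming the per-example function returns a single array rather than a pytree. `per_example` and `chunked_vmap` are hypothetical names chosen for illustration; the leftover partial chunk (when the batch size is not a multiple of `chunk_size`) is handled with one extra `vmap` call.

```python
import jax
import jax.numpy as jnp


def per_example(x):
    # Hypothetical per-example computation that builds a large
    # intermediate (an outer product) before reducing it to a scalar.
    outer = jnp.outer(x, x)
    return jnp.sum(jnp.tanh(outer))


def chunked_vmap(f, xs, chunk_size):
    """Apply vmap(f) over the leading axis of xs, chunk_size rows at a time.

    Assumes f maps one example to a single array (not a pytree).
    """
    batched_f = jax.vmap(f)
    n = xs.shape[0]
    num_chunks = n // chunk_size
    n_full = num_chunks * chunk_size
    if num_chunks == 0:
        # Fewer rows than one chunk: a single vmap call is enough.
        return batched_f(xs)
    # Reshape (n_full, ...) -> (num_chunks, chunk_size, ...).
    head = xs[:n_full].reshape((num_chunks, chunk_size) + xs.shape[1:])
    # lax.map loops over chunks sequentially; vmap vectorizes within a chunk.
    out = jax.lax.map(batched_f, head)
    # Flatten (num_chunks, chunk_size, ...) back to (n_full, ...).
    out = out.reshape((n_full,) + out.shape[2:])
    if n_full < n:
        # Process the leftover rows that did not fill a whole chunk.
        tail = batched_f(xs[n_full:])
        out = jnp.concatenate([out, tail], axis=0)
    return out


# Usage: 10 examples in chunks of 3 (three full chunks plus a tail of 1).
xs = jnp.arange(10 * 4, dtype=jnp.float32).reshape(10, 4) / 40.0
result = chunked_vmap(per_example, xs, chunk_size=3)
reference = jax.vmap(per_example)(xs)  # unchunked result for comparison
```

Because `lax.map` runs the chunks one after another, the large intermediates generally only need to exist for one chunk at a time, trading some parallelism for lower peak memory. If you wrap this in `jax.jit`, `f` and `chunk_size` have to be static (for example `static_argnums=(0, 2)`). Newer JAX releases may also accept a `batch_size` argument on `jax.lax.map` that performs this chunking directly; check the documentation for your installed version.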
This is discussed in #11319; there is currently talk about adding such an API (cc @froystig, @shoyer).