Using jax.pmap inside jax.lax.scan for multicore computation of the loop
#18477
Hello, apologies if this is a naive question, but I am trying to find faster ways to compile and execute a jit-compiled function that has to be called O(100k) times. To speed this up with multi-core computation, I am trying to use jax.pmap inside jax.lax.scan.
Replies: 1 comment
Hello,
I think I was needlessly confusing myself. With `pmap`, I don't need to use `lax.scan`. The problem is resolved simply by pmapping the interpolation functions:
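The poster's actual code is not included in the thread. Below is a minimal sketch of the idea, assuming the work is many independent interpolation calls. `jnp.interp` is a stand-in chosen for illustration, and `interpolate`, `queries`, and `n_per_dev` are hypothetical names. `jax.pmap` maps over the leading axis of its inputs and runs one slice per device, so the inputs are reshaped to `(n_devices, per_device)`. For comparison, the sketch also expresses the same work as a serial `jax.lax.scan`. Its carry is unused because the calls do not depend on each other, which is why `scan` is unnecessary here.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the poster's interpolation function:
# linear interpolation of a tabulated curve (xp, fp) at query points x.
def interpolate(x, xp, fp):
    return jnp.interp(x, xp, fp)

n_dev = jax.local_device_count()

# Tabulated data shared by every device (illustrative: y = x**2).
xp = jnp.linspace(0.0, 1.0, 101)
fp = xp ** 2

# Hypothetical workload: many independent query points, reshaped so the
# leading axis equals the number of devices (pmap maps over axis 0).
n_per_dev = 25_000
queries = jnp.linspace(0.0, 1.0, n_dev * n_per_dev).reshape(n_dev, n_per_dev)

# pmap compiles `interpolate` once and runs one row per device.
# in_axes=(0, None, None) splits `queries` along axis 0 but passes
# xp and fp unmapped (broadcast) to every device.
p_interp = jax.pmap(interpolate, in_axes=(0, None, None))
result = p_interp(queries, xp, fp)  # shape (n_dev, n_per_dev)

# For comparison: the same work as a serial lax.scan over the rows.
# The carry (None) is never updated because each call is independent.
def scan_body(carry, x_row):
    return carry, interpolate(x_row, xp, fp)

_, scanned = jax.lax.scan(scan_body, None, queries)
```

On a CPU-only machine `jax.local_device_count()` is usually 1. Setting the environment variable `XLA_FLAGS=--xla_force_host_platform_device_count=N` before importing `jax` makes XLA expose N host devices, which is a convenient way to try the pmapped version locally.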