Suppose I have some code like this:

```python
import jax
import numpy as np

# `fun` and `m` are assumed to be defined elsewhere.
def bar(n):
    x = 0
    def body(i, x):
        x += fun(x)
        return x
    return jax.lax.fori_loop(0, n, body, x)

out = jax.vmap(bar)(np.arange(m))
```

My question is, how many times will …
The Python function … For a vmapped … For more background on JAX's trace-based execution model, see How To Think In JAX.
`fori_loop` with a static bound lowers to `scan`, and `fori_loop` with a dynamic bound lowers to `while_loop`. Since you are mapping over `n` here, it lowers to a `while_loop`.
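When you `vmap` a `while_loop` over the bound, it results in a single `while_loop` over the batched body function, which keeps running until every element has reached its bound; elements that finish early simply keep their previous value. So in the end there will be `m - 1` iterations (the largest bound in `np.arange(m)`) over a vmapped call to `fun` with a batch size of `m`.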
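To make the batched behavior concrete, here is a hand-written emulation of the idea (a sketch, not JAX's actual batching code): one `while_loop` whose condition stays true while any element is unfinished, whose body applies `fun` to the whole batch, and which uses `jnp.where` to freeze elements that are already done. The `fun` definition and `m = 6` are hypothetical choices for illustration, and the extra `steps` counter exists only to expose the number of batched iterations.

```python
import jax
import jax.numpy as jnp

def fun(x):
    # Hypothetical stand-in for the question's `fun`.
    return jnp.sin(x) + 1.0

def bar(n):
    x = jnp.float32(0.0)
    def body(i, x):
        return x + fun(x)
    return jax.lax.fori_loop(0, n, body, x)

def batched_bar_by_hand(ns):
    # Emulates one while_loop over the whole batch.
    xs = jnp.zeros(ns.shape, dtype=jnp.float32)

    def cond(state):
        i, xs, steps = state
        # Keep looping while any element still has iterations left.
        return jnp.any(i < ns)

    def body(state):
        i, xs, steps = state
        new_xs = xs + fun(xs)               # `fun` sees the full batch every step
        xs = jnp.where(i < ns, new_xs, xs)  # freeze elements that are already done
        return i + 1, xs, steps + 1

    init = (jnp.int32(0), xs, jnp.int32(0))
    _, xs, steps = jax.lax.while_loop(cond, body, init)
    return xs, steps

m = 6
ns = jnp.arange(m)
out = jax.vmap(bar)(ns)
out_by_hand, steps = batched_bar_by_hand(ns)
print(out, out_by_hand, steps)
```

With `ns = jnp.arange(m)`, the loop stops once `i` reaches the largest bound, `m - 1`, so the emulation runs `m - 1` batched steps while still producing the same per-element results as `jax.vmap(bar)`.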
Does that answer your question?