vectorization and parallelization of code #9017
-
from jax import make_jaxpr, vmap
import jax.numpy as jnp
def f(x):
return x ** 2
x = jnp.arange(10)
# implicit vectorization via numpy-style broadcasting
print(make_jaxpr(f)(x))
# { lambda ; a:i32[10]. let b:i32[10] = integer_pow[y=2] a in (b,) }
# explicit vectorization via vmap
print(make_jaxpr(vmap(f))(x))
# { lambda ; a:i32[10]. let b:i32[10] = integer_pow[y=2] a in (b,) } As you can see, the two versions of the function result in the exact same computation being sent to the XLA compiler. I suspect this is similar to what's happening when you wrap your function in
As for taking advantage of in-device parallelism, XLA already does this to some extent without any explicit action on the user's part; for example, this is one reason computations can run so much faster on GPU and TPU. I'm not sure to what extent XLA analogously takes advantage of CPU threads; someone else may be able to answer that.
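If it helps to see the same point at runtime, here is a minimal timing sketch (not part of the original reply; the array size is just an illustrative choice): since both versions lower to the same computation, their execution times should be essentially identical once compilation time is excluded.

import timeit
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

x = jnp.arange(10_000_000)

# Compile both the implicitly broadcast and the explicitly vmapped versions.
f_implicit = jax.jit(f)
f_vmapped = jax.jit(jax.vmap(f))

# Warm up once so compilation is excluded, and block on results when timing.
f_implicit(x).block_until_ready()
f_vmapped(x).block_until_ready()

print(timeit.timeit(lambda: f_implicit(x).block_until_ready(), number=100))
print(timeit.timeit(lambda: f_vmapped(x).block_until_ready(), number=100))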
-
I have a two-dimensional array of size 2621440x4 that I want to vectorize. This translates to an object of shape [2621440, 4], where the 4 is the fastest-varying dimension ("C"-style, row-major layout).
I am applying some computation to each of the elements. The computation is pointwise and embarrassingly parallel. Hence, I apply vmap over the second dimension; however, I do not see any difference in execution time. I try with and without vmap and get the same execution time, and I jit-compile both cases.
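For reference, here is a rough sketch of the setup described above; the pointwise function g is a hypothetical stand-in, since the actual computation is not shown in the post.

import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical stand-in for the actual pointwise computation.
    return jnp.sin(x) * jnp.cos(x) + x ** 2

data = jnp.zeros((2621440, 4))  # shape [2621440, 4]

# Without vmap: rely on numpy-style broadcasting over the whole array.
g_jit = jax.jit(g)

# With vmap over the second dimension, as described above.
g_vmap_jit = jax.jit(jax.vmap(g, in_axes=1, out_axes=1))

# Both lower to the same element-wise computation, which is consistent with
# the two versions showing the same execution time.
r1 = g_jit(data).block_until_ready()
r2 = g_vmap_jit(data).block_until_ready()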
Also, if I want to parallelize this over the threads of a CPU or across CPU sockets, will pmap help with that?
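For concreteness, the kind of CPU-level data parallelism being asked about might look like the sketch below. The XLA_FLAGS setting used to expose multiple CPU devices is an assumption added for illustration, not something taken from this thread, and it must be set before JAX is imported.

# Assumption for illustration: expose 8 CPU "devices" to XLA so that pmap
# has something to map over. This must happen before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

def g(x):
    return x ** 2  # hypothetical stand-in for the pointwise computation

print(jax.devices())  # should now list 8 CPU devices

data = jnp.zeros((2621440, 4))
# pmap maps over the leading axis, which must equal the device count, so the
# rows are first grouped into [num_devices, rows_per_device, 4].
sharded = data.reshape(8, -1, 4)
result = jax.pmap(g)(sharded).reshape(2621440, 4)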
Thanks.