-
Consider working with a Colab-provided 8-device TPU. The pmap cookbook shows an example of doing parallel work with pmap by placing data across devices and then doing some calculation, e.g.:
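(Roughly like this; a minimal sketch in the spirit of the cookbook, with illustrative matrix shapes:)

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

# One PRNG key per device (8 on a Colab TPU).
keys = random.split(random.PRNGKey(0), jax.local_device_count())

# Each device draws its own random matrix, then multiplies it by its transpose.
matrices = pmap(lambda k: random.normal(k, (100, 200)))(keys)
result = pmap(lambda x: jnp.dot(x, x.T))(matrices)
print(result.shape)  # (8, 100, 100)
```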
If we try to do this with a leading dimension that's not equal to the number of devices, we get an error.
If the leading dim is a multiple of the number of devices, we can get around this by 1) reshaping to inject a dummy axis, 2) further mapping across the dummy axis with a vmap inside the pmap, and 3) reshaping the final result to get rid of the dummy axis, as in the sketch below.
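(A minimal sketch of that workaround, assuming 16 items across 8 devices; the sampled shapes are just placeholders:)

```python
import jax
import jax.numpy as jnp
from jax import pmap, vmap, random

n_devices = jax.local_device_count()  # 8 on a Colab TPU

# 16 keys: a multiple of the device count, but pmap alone fails here
# because the leading dimension (16) exceeds the number of devices (8).
keys = random.split(random.PRNGKey(0), 16)

f = lambda k: random.normal(k, (100, 200))

# 1) inject a dummy axis: (16, ...) -> (8, 2, ...)
keys = keys.reshape((n_devices, -1) + keys.shape[1:])
# 2) vmap over the dummy axis inside the pmap: result is (8, 2, 100, 200)
out = pmap(vmap(f))(keys)
# 3) reshape the final result to get rid of the dummy axis: (16, 100, 200)
out = out.reshape((-1,) + out.shape[2:])
print(out.shape)  # (16, 100, 200)
```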
This works for my use case; I just wanted to make sure I wasn't missing something about pmap that could do this for me?
-
This is an annoying wart of pmap, which we hope to revise soon! We have a prototype replacement checked in, called gmap (from #4006), which will allow schedulable maps, so that you can control how the map is evaluated as a combination of parallelization, vectorization, and iteration (like your manual pmap+vmap, but without requiring the reshape, and without requiring you to have two separate axis names). But while that's the long-term solution, it's not ready yet (in particular because it doesn't work efficiently with ShardedDeviceArrays). (cc @apaszke)

However, I suggested you ask about this on GitHub because there is another (older) prototype you could try: soft_pmap (…). First enable omnistaging:

```python
from jax.config import config
config.enable_omnistaging()
```

In terms of your example, you should be able to write this (where I effectively merged the first two lines):

```python
import jax.numpy as jnp
from jax import random
from jax import soft_pmap
from jax.config import config

config.enable_omnistaging()

# 16 keys across 8 devices: soft_pmap handles the mismatch for you.
keys = random.split(random.PRNGKey(0), 16)
matrices = soft_pmap(lambda k: random.normal(k, (100, 200)))(keys)
result = soft_pmap(lambda x: jnp.dot(x, x.T))(matrices)

print(result.shape)  # (16, 100, 100)
print(type(result))  # <class 'jax.interpreters.pxla.ShardedDeviceArray'>
```
-
That's perfect! I'll switch to soft_pmap.