Using jax.pmap/vmap and jax.lax.switch to execute a simulation in parallel. Is it an anti-pattern?
#20916
-
My question: is this some form of an anti-pattern?

Toy example code:

```python
from typing import Any, Protocol, Tuple

import jax
import jax.numpy as jnp

PyTree = Any


class Generator(Protocol):
    def __call__(self, key: jax.Array) -> PyTree:
        ...


def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
    # Map each sample index to the generator it belongs to,
    # e.g. [2, 1] -> [0, 0, 1].
    arr = []
    for i, l in enumerate(batchsizes):
        arr += [i] * l
    return jnp.array(arr)


def _distribute_batchsize(batchsize: int) -> Tuple[int, int]:
    """Split the total batchsize into (pmap_size, vmap_size)."""
    vmap_size_min = 8
    if batchsize <= vmap_size_min:
        return 1, batchsize
    else:
        n_devices = jax.local_device_count()
        assert (
            batchsize % n_devices
        ) == 0, f"Your GPU count of {n_devices} does not split batchsize {batchsize}"
        vmap_size = int(batchsize / n_devices)
        return int(batchsize / vmap_size), vmap_size


def _merge_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
    # Flatten the (pmap, vmap) leading axes back into a single batch axis.
    return jax.tree_util.tree_map(
        lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
    )


def batch_generators_lazy(
    generators: list[Generator],
    batchsizes: list[int],
) -> Generator:
    """Create a large generator by stacking multiple generators lazily."""
    assert len(generators) == len(batchsizes)

    batch_arr = _build_batch_matrix(batchsizes)
    bs_total = len(batch_arr)
    pmap_size, vmap_size = _distribute_batchsize(bs_total)
    batch_arr = batch_arr.reshape((pmap_size, vmap_size))

    @jax.pmap
    @jax.vmap
    def _generator(key, which_gen: int):
        return jax.lax.switch(which_gen, generators, key)

    def generator(key):
        pmap_vmap_keys = jax.random.split(key, bs_total).reshape(
            (pmap_size, vmap_size, 2)
        )
        data = _generator(pmap_vmap_keys, batch_arr)
        data = _merge_batchsize(data, pmap_size, vmap_size)
        return data

    return generator


def generator_factory(hyperparams) -> Generator:
    def generator(key):
        # expensive simulation
        return dict(X=jnp.array(0.0) + hyperparams)

    return generator


M = 4
N_m = 16
generators = [generator_factory(hyperparam) for hyperparam in range(M)]
batchsizes = M * [N_m]

batched_generator = batch_generators_lazy(generators, batchsizes)
batched_generator(jax.random.PRNGKey(1))
```
-
As nobody else has answered yet I'll just suggest a comment, although I cannot test it without a reproducible example. I believe this is due to `jax.lax.switch` evaluating every branch when the branch index is a traced (e.g. vmapped) value. My solution was to replace `return jax.lax.switch(which_gen, generators, key)` with plain list indexing, `return generators[which_gen](key)`. This avoids the evaluation of all branches during the switch statement.
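A small sketch (my own, not from the original post) illustrating the cost being described: when the index of a `jax.lax.switch` is batched under `vmap`, JAX cannot pick a single branch at trace time, so every branch ends up in the traced computation, as the jaxpr shows:

```python
import jax
import jax.numpy as jnp

branches = [jnp.sin, jnp.cos]


def f(which, x):
    return jax.lax.switch(which, branches, x)


# With a batched `which`, all branches are traced (and evaluated,
# then selected per element), so both sin and cos appear in the jaxpr.
jaxpr = jax.make_jaxpr(jax.vmap(f))(jnp.array([0, 1]), jnp.array([0.1, 0.2]))
print(jaxpr)
```

Note that the suggested plain list indexing, `generators[which_gen](key)`, only works when `which_gen` is a concrete Python integer, i.e. when it is not itself a traced/vmapped value.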
-
Did you manage to resolve this issue?
No, I never found a better/faster solution to this. I ended up just optimizing other parts to avoid having to call this logic too often, and accepted that it takes a couple of hours to generate the data.
If I were to do it again, I would probably not implement this part in JAX. If you need lots of branching (e.g. via `jax.lax.switch`), then JAX might not be the best option.
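For completeness, one way to sidestep the branching entirely (my own sketch, not what the thread settled on): since each sample's generator index is known before tracing, the keys can be grouped per generator on the host, each generator vmapped only over its own slice, and the results concatenated. This trades the `pmap` sharding for simple per-generator vmaps:

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp


def batch_generators_eager(
    generators: list[Callable[[jax.Array], Any]],
    batchsizes: list[int],
    key: jax.Array,
):
    # One PRNG key per sample, assigned contiguously per generator.
    keys = jax.random.split(key, sum(batchsizes))
    outputs, start = [], 0
    for gen, n in zip(generators, batchsizes):
        # vmap each generator only over its own keys -- no lax.switch needed.
        outputs.append(jax.vmap(gen)(keys[start:start + n]))
        start += n
    # Concatenate the per-generator pytrees along the batch axis.
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs), *outputs)
```

Because each `vmap` sees a single generator, no branch is ever traced more than once; device parallelism would have to be reintroduced separately (e.g. via `jax.jit` with sharding), which is beyond this sketch.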