Is there a way for different "slices" (i.e. the elements being vmap-ed over) to see each other at all? I know there are ways to modify/control vmap from inside a vmap-ed function (e.g. https://github.com/patrick-kidger/equinox/blob/main/equinox/_unvmap.py), but I don't clearly see how this would map to acquiring information across the batch and propagating it. For example:

```python
def f(x):
    return x + unvmap(x).mean()

jax.vmap(f)(jnp.arange(10))
```

This example is trivial of course (and could be done without an `unvmap`); in practice it would be a smaller subroutine that could not be extracted so easily (for those more interested, the motivation is patrick-kidger/diffrax#481), but it shows the main goal.
Replies: 1 comment 2 replies
This kind of communication is supported in the specific case of aggregating across the batch axis, via named collectives. Modifying your example a bit:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x + x.mean('i')

x = jnp.arange(5)
jax.vmap(f, axis_name='i')(x)
# Array([2., 3., 4., 5., 6.], dtype=float32)
```
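For comparison, the same aggregation can probably also be spelled with the explicit collectives in `jax.lax`. The following is a minimal sketch, not taken from the reply above: it assumes `jax.lax.pmean` and `jax.lax.psum` are acceptable substitutes for the named-axis `mean`, and the function names `add_batch_mean` and `fraction_of_total` are made up for illustration.

```python
import jax
import jax.numpy as jnp

def add_batch_mean(x):
    # Inside vmap, x is a single slice (a scalar here).
    # pmean averages x over all slices along the named axis 'i'.
    return x + jax.lax.pmean(x, 'i')

def fraction_of_total(x):
    # psum adds up x over all slices along the named axis 'i',
    # so each slice can be normalized by the batch total.
    return x / jax.lax.psum(x, 'i')

x = jnp.arange(5.0)
shifted = jax.vmap(add_batch_mean, axis_name='i')(x)
fractions = jax.vmap(fraction_of_total, axis_name='i')(x)
```

The key point is that the axis name passed to `jax.vmap` via `axis_name` must match the name used by the collective inside the function; without `axis_name`, the collective has no axis to reduce over.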