Communicate across data inside a vmap #22974

Closed. Answered by jakevdp
lockwo asked this question in Q&A

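This kind of communication is supported in the specific case of aggregating across the batch axis, via named collectives: give the mapped axis a name with the axis_name argument of jax.vmap, then reduce over that name inside the function. Modifying your example a bit: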

import jax
import jax.numpy as jnp

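def f(x):
  # 'i' is the axis name bound by jax.vmap below; averaging over it
  # combines values from every element of the batch
  return x + x.mean('i')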

x = jnp.arange(5)
jax.vmap(f, axis_name='i')(x)
# Array([2., 3., 4., 5., 6.], dtype=float32)
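For comparison, here is a minimal sketch (assuming a recent JAX release) of the same idea written with explicit collectives from jax.lax. jax.lax.pmean, jax.lax.psum and jax.lax.all_gather all accept the axis name that jax.vmap binds. The helper names normalize and batch_max are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # pmean averages x over every element of the vmapped axis named 'i'
    return x + jax.lax.pmean(x, axis_name='i')

def normalize(x):
    # hypothetical helper: psum gives each element the total over the batch
    return x / jax.lax.psum(x, axis_name='i')

def batch_max(x):
    # hypothetical helper: all_gather hands every element the full batch,
    # so arbitrary cross-element logic (here, a max) becomes possible
    full = jax.lax.all_gather(x, axis_name='i')
    return full.max()

x = jnp.arange(5.0)
print(jax.vmap(f, axis_name='i')(x))
print(jax.vmap(normalize, axis_name='i')(x))
print(jax.vmap(batch_max, axis_name='i')(x))
```

all_gather is the most general of the three since each element can see the whole batch, but it can materialize a full copy of the batch per element, so a reduction such as psum or pmean is usually the cheaper choice when it is enough.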

Answer selected by lockwo