How do I shard jax.ops.segment_sum correctly? #16680
-
I am trying to parallelise `jax.ops.segment_sum` across two devices with the following code:
from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import numpy as np
@partial(jax.jit, static_argnums=(2,))
def ss(y, x, m):
    """Sum the entries of ``y`` into ``m`` segments selected by ``x``.

    Args:
        y: Values to accumulate.
        x: Segment id for each entry of ``y``; no ordering is assumed
            (``indices_are_sorted=False``).
        m: Number of output segments. Marked static for ``jax.jit`` via
            ``static_argnums=(2,)`` and forwarded as ``num_segments``.

    Returns:
        The result of ``jax.ops.segment_sum`` with ``num_segments=m``.
    """
    return jax.ops.segment_sum(y, x, indices_are_sorted=False, num_segments=m)
# Problem size: m output segments, n input elements.
m, n = 1024, 2 * 1024

# A random segment id in [0, m) for every element, plus the values to be summed.
x = jnp.array(np.random.randint(0, m, n))
y = jnp.arange(n, dtype=int)

# Build a mesh of shape (2,) and place both inputs with the same sharding.
mesh = mesh_utils.create_device_mesh((2,))
sharding = PositionalSharding(mesh)
x, y = (jax.device_put(operand, sharding) for operand in (x, y))

# Report how each input ended up laid out before calling the jitted function.
print(x.shape, x.sharding)
print(y.shape, y.sharding)

s = ss(y, x, m)
print(x.shape, y.shape, (s.shape))
but this gives the following error:
Where am I going wrong? System info:
|
Beta Was this translation helpful? Give feedback.
Answered by
benmoseley
Nov 9, 2023
Replies: 1 comment
-
Update: This works when upgrading to jax=0.4.20.
|
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
benmoseley
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Update: This works when upgrading to jax=0.4.20.
Output: