How do I shard jax.ops.segment_sum correctly? #16680

Closed. Answered by benmoseley
benmoseley asked this question in Q&A

Update: This works after upgrading to jax 0.4.20.
Output:

(2048,) PositionalSharding([{GPU 0} {GPU 1}])
(2048,) PositionalSharding([{GPU 0} {GPU 1}])
(2048,) (2048,) (1024,)
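The original question's code is not part of this excerpt. The following is a minimal sketch, under stated assumptions, of one way to shard the inputs to `jax.ops.segment_sum` across devices so that the shapes match the printed output: two `(2048,)` inputs and a `(1024,)` result. It uses `Mesh` and `NamedSharding` rather than the `PositionalSharding` shown in the output, because `NamedSharding` is the more general sharding API. The helper name `seg_sum`, the all-ones data, the "two elements per segment" ids, and `num_segments=1024` are hypothetical choices made for illustration. The sketch also assumes the device count divides 2048.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices (e.g. 2 GPUs, or a single CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
# Shard the leading (only) axis of each input across the "x" mesh axis.
sharding = NamedSharding(mesh, P("x"))

n, num_segments = 2048, 1024  # hypothetical sizes matching the printed shapes

data = jnp.ones((n,), dtype=jnp.float32)
segment_ids = jnp.arange(n) // 2  # hypothetical: each segment gets two elements

# Place both inputs on the devices with the same sharding.
data = jax.device_put(data, sharding)
segment_ids = jax.device_put(segment_ids, sharding)


@jax.jit
def seg_sum(d, ids):
    # num_segments is a Python int closed over here, so the output shape is static.
    return jax.ops.segment_sum(d, ids, num_segments=num_segments)


out = seg_sum(data, segment_ids)
print(data.shape, data.sharding)
print(segment_ids.shape, segment_ids.sharding)
print(data.shape, segment_ids.shape, out.shape)
```

Passing `num_segments` as a static value matters under `jit`. Without it, `segment_sum` would have to infer the output size from the values of `segment_ids`, and those values are not known at trace time.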

Answer selected by benmoseley