Combine two arrays on different devices #14145
Unanswered
mehdiataei asked this question in General
Replies: 2 comments 14 replies
-
Another possible solution would be to somehow access only the `addressable_data` of a sharded array and modify it in place. But I don't think this is possible in JAX currently. You also cannot access `addressable_data` inside a jitted function.
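To make the `addressable_data` idea concrete, here is a minimal sketch. It assumes a recent JAX with `jax.sharding.NamedSharding`, and it simulates two CPU devices on one machine via the `--xla_force_host_platform_device_count` XLA flag (a testing assumption, not something from this thread). Outside `jit`, each device-local piece can be read through `addressable_shards` or `addressable_data(i)`. Since `jax.Array`s are immutable, the closest thing to an in-place update is to compute a new shard on the same device and reassemble a new global array with `jax.make_array_from_single_device_arrays`. The mesh axis name `"x"` is arbitrary.

```python
import os

# Assumption: simulate 2 host (CPU) devices; must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices("cpu")[:2]
# 1-D mesh over the two devices; P("x") splits axis 0 evenly across it.
sharding = NamedSharding(Mesh(np.array(devices), ("x",)), P("x"))

# A length-8 array split over 2 devices: 4 elements per device.
arr = jax.device_put(jnp.arange(8.0), sharding)

# Outside jit, each device-local piece is a single-device jax.Array.
for shard in arr.addressable_shards:
    print(shard.device, shard.index, shard.data)

# addressable_data(i) returns the data of the i-th addressable shard.
first = arr.addressable_data(0)

# No in-place mutation: scale shard 0 on its own device, keep shard 1 as-is,
# then wrap the pieces into a new global array with the same sharding.
new_pieces = [
    s.data * 10 if i == 0 else s.data
    for i, s in enumerate(arr.addressable_shards)
]
updated = jax.make_array_from_single_device_arrays(arr.shape, sharding, new_pieces)
print(updated)
```

Everything here runs eagerly, outside `jit`, which matches the limitation mentioned above; per-device logic inside a jitted function is usually expressed with `shard_map` instead.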
-
If you want to combine them along the leading axis (stacking them into a new leading dimension), effectively creating a sharded array, I suspect the best approach would be using `jax.device_put_sharded`:

```python
import jax
import numpy as np

devices = jax.devices()
# Place one piece on each of the first two devices.
x = jax.device_put(np.zeros(10), device=devices[0])
y = jax.device_put(np.ones(10), device=devices[1])
# Stack the pieces along a new leading axis; row i stays on devices[i].
xy = jax.device_put_sharded([x, y], devices=devices[:2])
print(xy)
# [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
print(xy.devices())
# [CpuDevice(id=1), CpuDevice(id=0)]
```
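For concatenating along an existing axis (rather than stacking along a new one), a sketch that may fit the `jax.Array` API better is to wrap the buffers already living on each device into one global array with `jax.make_array_from_single_device_arrays`. As far as I understand, this only builds the global array around the existing per-device buffers, so no data is gathered onto a single device. Assumptions: two CPU devices simulated via the `--xla_force_host_platform_device_count` XLA flag, equal-sized pieces, and an arbitrary mesh axis name `"x"`.

```python
import os

# Assumption: simulate 2 host (CPU) devices; must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices("cpu")[:2]

# The two pieces, each already committed to its own device.
a = jax.device_put(np.zeros((10, 3), dtype=np.float32), devices[0])
b = jax.device_put(np.ones((10, 3), dtype=np.float32), devices[1])

# 1-D mesh over the two devices; P("x", None) splits axis 0 across it.
mesh = Mesh(np.array(devices), ("x",))
sharding = NamedSharding(mesh, P("x", None))

# Global shape is the concatenation along axis 0: (10 + 10, 3).
# Each input must already sit on the device the sharding assigns to it.
ab = jax.make_array_from_single_device_arrays((20, 3), sharding, [a, b])

print(ab.shape)  # (20, 3)
print(ab.sharding)
```

To concatenate along axis 1 instead, pass the concatenated global shape (here `(10, 6)`) and shard that axis, e.g. `P(None, "x")`. A `NamedSharding` splits an axis evenly, so this approach assumes both pieces have the same shape.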
-
I want to concatenate arrays `a` and `b`, located on devices 0 and 1 respectively, along a specific axis without transferring the data back to a single device (that would significantly hurt performance). I am using the new `jax.Array` and am unsure whether `device_put_sharded` is a viable solution for this task.
Any thoughts?
Thanks