-
There are two different methods for sharding an array, namely:

```python
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
import numpy as np

devices = jax.devices()
a0 = jnp.zeros((1024, 2048))

devices_a1 = mesh_utils.create_device_mesh((1, 8))
devices_a2 = np.array(devices).reshape(1, 8)

sharding_a1 = NamedSharding(mesh=Mesh(devices_a1, ('a', 'b')), spec=P('a', 'b'))
sharding_a2 = NamedSharding(mesh=Mesh(devices_a2, ('a', 'b')), spec=P('a', 'b'))

a1 = jax.device_put(a0, sharding_a1)
a2 = jax.device_put(a0, sharding_a2)

jax.debug.visualize_array_sharding(a1)
jax.debug.visualize_array_sharding(a2)
```

I originally thought that these two methods would yield the same result, but it turns out that the order of devices is different (as shown by `visualize_array_sharding`). Given that the results differ, which sharding method should I use? Are there any performance differences between the two when performing parallel computations? Possibly related to #19661.
Replies: 1 comment
-
The `mesh_utils` version takes into account the hardware geometry, and will result in a mesh with a more efficient layout for the particular hardware you are running on. For this reason you should use `mesh_utils` rather than constructing a mesh manually from the ordered list of devices.
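For illustration, here is a minimal sketch (assuming a single host with 8 devices, as in the original snippet) of how one might inspect the device ordering each approach produces; on TPU hardware the two orderings will generally differ, while on some single-host setups they may coincide:

```python
import jax
import numpy as np
from jax.experimental import mesh_utils

devices = jax.devices()

# Hardware-aware layout: orders devices to match the physical interconnect topology.
devices_a1 = mesh_utils.create_device_mesh((1, 8))

# Manual layout: simply reshapes jax.devices() in its default enumeration order.
devices_a2 = np.array(devices).reshape(1, 8)

# Print the device ids along the mesh axes to compare the two orderings.
print([d.id for d in devices_a1.ravel()])
print([d.id for d in devices_a2.ravel()])
```

Either mesh produces a valid sharding; the difference is that the hardware-aware ordering tends to place neighboring mesh positions on physically adjacent devices, which matters for communication-heavy collectives.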