Replies: 1 comment
After further investigation, I've managed to achieve the behavior I was looking for using the following script:

```python
# Launched on each host with flags along the lines of:
#   python script.py --server_addr=<coordinator ip:port> --num_hosts=2 --host_idx=<0 or 1>
import jax
import numpy as np
import jax.numpy as jnp
import functools
from absl import app
from absl import flags
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

flags.DEFINE_string("server_addr", "", help="server ip addr")
flags.DEFINE_integer("num_hosts", 1, help="num of hosts")
flags.DEFINE_integer("host_idx", 0, help="index of current host")
FLAGS = flags.FLAGS


def f(x):
    return x


def main(argv):
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)
    devices = jax.devices()
    local_devices = jax.local_devices()
    print("host_idx:", FLAGS.host_idx)
    print("devices:", devices)
    print("local_devices:", local_devices)

    # Mesh over *all* devices in the job (both hosts), not just the local ones.
    mesh = Mesh(np.array(devices), ("i",))
    sharding = NamedSharding(mesh, P("i"))
    replicated_sharding = NamedSharding(mesh, P())

    # Build a host-local array, then assemble the per-host pieces into one global jax.Array.
    x = 8 * FLAGS.host_idx + jnp.arange(8)
    global_array = multihost_utils.host_local_array_to_global_array(x, mesh, P("i"))

    x_s = shard_map(f, mesh, in_specs=P("i"), out_specs=P("i"))(global_array)
    print("x:", x)
    print(jax.debug.visualize_array_sharding(x))
    print("x_s", multihost_utils.process_allgather(x_s))
    print(jax.debug.visualize_array_sharding(x_s))

    # psum over the mesh axis "i" now runs across all devices on all hosts.
    @functools.partial(
        shard_map,
        mesh=mesh,
        in_specs=P("i"),
        out_specs=P("i"),
    )
    def psum_data(data):
        return jax.lax.psum(data, "i")

    p_sum_out = psum_data(global_array)
    print("devices buffers x_s", [shard.data for shard in x_s.addressable_shards])
    print(
        "devices buffers p_sum_out",
        [shard.data for shard in p_sum_out.addressable_shards],
    )
    print(
        "output after taking psum (global_array):",
        multihost_utils.process_allgather(p_sum_out),
    )
    print(
        "output after taking psum (local_array):",
        multihost_utils.global_array_to_host_local_array(p_sum_out, mesh, P("i")),
    )


if __name__ == "__main__":
    app.run(main)
```

The key changes that made this work:
While this approach works, I have some additional questions. Any insights or best practices would be greatly appreciated. Thank you!
Hi everyone! I'm exploring the differences between shard_map and pmap functionalities in JAX, particularly in a multi-host setting. I've encountered some behavior that I'd like to understand better and potentially find a solution for.

Setup: Multi-Host Environment
Consider a setup with 2 hosts, each having 4 devices.
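(Concretely, once jax.distributed.initialize has run, each process should report something like the following; the exact device types depend on the hardware.)

```python
import jax

# Run on every host after jax.distributed.initialize(...):
print(jax.process_count())       # 2  (hosts in the job)
print(jax.local_device_count())  # 4  (devices attached to this host)
print(jax.device_count())        # 8  (devices across the whole job)
```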
Example 1: Using pmap

Here's a basic script using pmap:
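(The script below is a sketch from memory rather than the exact file I ran, so details may differ slightly; the data layout, arange(16) split two values per device across the 8 devices, is chosen to match the output quoted underneath.)

```python
import functools

import jax
import numpy as np
from absl import app
from absl import flags

flags.DEFINE_string("server_addr", "", help="server ip addr")
flags.DEFINE_integer("num_hosts", 1, help="num of hosts")
flags.DEFINE_integer("host_idx", 0, help="index of current host")
FLAGS = flags.FLAGS


def main(argv):
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)

    # Each host builds the rows that live on its own 4 devices:
    # host 0 holds [[0 1] [2 3] [4 5] [6 7]], host 1 holds [[8 9] ... [14 15]].
    local_data = (8 * FLAGS.host_idx + np.arange(8)).reshape(4, 2)

    @functools.partial(jax.pmap, axis_name="i")
    def psum_data(x):
        return jax.lax.psum(x, "i")

    out = psum_data(local_data)
    # pmap's psum runs over the global axis "i" (all 8 devices on both hosts),
    # so every local shard ends up holding [56 64].
    print("output after taking psum:", out)


if __name__ == "__main__":
    app.run(main)
```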
With pmap, the final output of taking psum yields [56 64] across all shards, as expected.

Example 2: Attempting to Use shard_map

Now, I tried to achieve the same result using shard_map:
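(Again sketched from memory, so the exact script may have differed; the important detail in this sketch is that the Mesh is built from jax.local_devices() only, which reproduces the per-host outputs shown below.)

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from absl import app
from absl import flags
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

flags.DEFINE_string("server_addr", "", help="server ip addr")
flags.DEFINE_integer("num_hosts", 1, help="num of hosts")
flags.DEFINE_integer("host_idx", 0, help="index of current host")
FLAGS = flags.FLAGS


def main(argv):
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)

    # Mesh over this host's 4 local devices only.
    mesh = Mesh(np.array(jax.local_devices()), ("i",))

    # Host-local data: 8 values per host, two per local device.
    x = 8 * FLAGS.host_idx + jnp.arange(8)

    @functools.partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
    def psum_data(data):
        return jax.lax.psum(data, "i")

    out = psum_data(x)
    # The psum only sees the 4 local devices, so each shard holds [12 16]
    # on host 0 and [44 48] on host 1, not the global [56 64].
    print("device buffers:", [shard.data for shard in out.addressable_shards])


if __name__ == "__main__":
    app.run(main)
```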
Observed Behavior and Questions
With shard_map, I'm getting different outputs for the shards on each host:

[12 16]
[44 48]

It seems like the shards are not aware of each other across hosts when using shard_map.
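(For what it's worth, these numbers are consistent with per-host partial sums: with two values per device, host 0's devices hold [0 1], [2 3], [4 5], [6 7], so 0+2+4+6 = 12 and 1+3+5+7 = 16, while host 1's devices give 8+10+12+14 = 44 and 9+11+13+15 = 48; a sum over all eight devices would give the 56 and 64 that pmap reports.)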
Questions
- Is shard_map compatible with the behavior I want, i.e., to perform operations across all devices on all hosts? Or is it better to stick to pmap for this functionality?
- … shard_map?

Any insights or suggestions would be greatly appreciated. Thanks in advance for your help!