Replies: 1 comment 5 replies
Yes! That is correct: since each device gets a unique piece of the data, you shard the batch across all axes of the mesh. But I would recommend using
TL;DR: What is the recommended way to use `shard_map` with host-local data loading in multi-host training?

I'm trying to migrate from `pmap` to the `shard_map` API. First, we create a mesh, and we need to define the partition specs for the model and the data. I'm using simple data parallelism, so the model sharding is easy:
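Schematically, something like this (the concrete axis names and the 2D hosts × local-devices layout here are placeholders, not the exact code):

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# Placeholder axis names; below they are referred to as HOST_AXIS and DEVICE_AXIS.
HOST_AXIS = "hosts"
DEVICE_AXIS = "devices"

# One mesh row per host process, one column per local device.
devices = np.asarray(jax.devices()).reshape(
    jax.process_count(), jax.local_device_count()
)
mesh = Mesh(devices, axis_names=(HOST_AXIS, DEVICE_AXIS))

# Pure data parallelism: the model is replicated on every device.
model_spec = P()
```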
I'm confused about how to partition the host-local data. Since each host loads its own data, I simply want each host to shard along the local device axis:
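That is, roughly the following (continuing the placeholder names from the sketch above; the loss computation is just a stand-in):

```python
from functools import partial

import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map

# Only the local-device axis appears in the data spec; each host passes in
# the batch it loaded itself.
data_spec = P(DEVICE_AXIS)

@partial(shard_map, mesh=mesh, in_specs=(model_spec, data_spec), out_specs=P())
def mean_loss(params, batch):
    # Each program instance sees its per-device slice of `batch`; reduce over
    # both mesh axes so the result is fully replicated, matching out_specs=P().
    local = jnp.mean(batch)
    return lax.pmean(local, (HOST_AXIS, DEVICE_AXIS))
```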
Surprisingly, this works in my setup (the loss matches the `pmap` case). However, I'm curious whether I'm relying on undefined or poorly defined behavior: given a global mesh, `P(DEVICE_AXIS)` seems to imply that the data should be sharded across the device axis and then replicated across hosts, which is not what I want with host-local data loading. Should I instead use `host_local_array_to_global_array` to collect the host-local data along the batch dimension, and then partition with `P((HOST_AXIS, DEVICE_AXIS))` to shard the aggregated batch dimension across devices (roughly as sketched below)?

Thanks in advance.
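Concretely, the alternative I'm asking about would look something like this (same placeholder names as above; `local_batch` stands in for whatever this host's dataloader produced):

```python
from jax.experimental import multihost_utils

local_batch = np.zeros((8, 128), dtype=np.float32)  # stand-in host-local batch
global_spec = P((HOST_AXIS, DEVICE_AXIS))

# Stitch the per-host batches into one global jax.Array whose leading (batch)
# dimension is sharded over both mesh axes.
global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, global_spec
)

@partial(shard_map, mesh=mesh, in_specs=(model_spec, global_spec), out_specs=P())
def mean_loss_global(params, batch):
    local = jnp.mean(batch)
    return lax.pmean(local, (HOST_AXIS, DEVICE_AXIS))
```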