[llama4]: flatten dp_mesh to enable data parallel replicate (#1279)
One-line fix to enable training with HSDP.
When DP Replicate is enabled, `dp_mesh` contains `(dp_shard_cp,
dp_replicate)`, and `dp_mesh.get_group()` throws an error because the
mesh has multiple axes.
So I flatten it the same way it is done in the `build_device_mesh()`
function (with the same name, just to be explicit).
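
To illustrate the failure mode and the fix described above, here is a minimal toy sketch in plain Python. It does not use the PyTorch `DeviceMesh` API: `ToyMesh`, its `flatten()` method, the axis sizes, and the flattened name `dp_cp` are all assumptions made for illustration only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for a device mesh: named axes plus ranks."""

    dim_names: tuple
    shape: tuple
    ranks: tuple  # row-major ranks; length equals the product of shape

    def get_group(self):
        # Models the behaviour described in the commit message: a single
        # process group is only well defined for a one-axis mesh.
        if len(self.dim_names) != 1:
            raise RuntimeError(
                f"mesh has {len(self.dim_names)} axes {self.dim_names}; expected 1"
            )
        return list(self.ranks)

    def flatten(self, name):
        # Collapse all axes into one named axis, preserving rank order.
        return ToyMesh((name,), (len(self.ranks),), self.ranks)


# 4 shards x 2 replicas = 8 ranks (sizes chosen only for illustration).
dp_mesh = ToyMesh(("dp_shard_cp", "dp_replicate"), (4, 2), tuple(range(8)))

try:
    dp_mesh.get_group()
    raised = False
except RuntimeError:
    raised = True  # two axes, so no single group can be returned

# After flattening, the mesh has one axis and get_group() succeeds.
flat_mesh = dp_mesh.flatten("dp_cp")
group = flat_mesh.get_group()
```

In this toy model the flattened mesh covers every data-parallel rank in a single axis, which is the property the fix relies on.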
Note: I should be able to use `world_mesh['dp_cp']` here, as it's
already flattened when world_mesh is built. However, inside the function
when it is executed, this flattened mesh is not available for some reason.