Commit 5e09057

[llama4]: flatten dp_mesh to enable data parallel replicate (#1279)
One-liner fix to enable training with HSDP. When DP Replicate is enabled, dp_mesh contains `(dp_shard_cp, dp_replicate)`, and `dp_mesh.get_group()` throws an error because the mesh has multiple axes. So I flatten it the same way it is done in the `build_device_mesh()` function, with the same name, just to be explicit. Note: I should be able to use `world_mesh['dp_cp']` here, as it is already flattened when world_mesh is built. However, for some reason that flattened mesh is not available inside this function when it executes.
1 parent 6bffbfc commit 5e09057
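
For context, here is a minimal standalone sketch of the failure mode and of the flatten workaround. It is not code from this repository: the 2x2 CPU mesh, the dim names, and the torchrun launch are assumptions chosen to mirror the commit message, and DeviceMesh._flatten is a private PyTorch API used here purely for illustration.

# Illustrative sketch only. Run with: torchrun --nproc_per_node=4 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Dim names mirror the commit message; a real torchtitan world mesh has more axes.
mesh = init_device_mesh(
    "cpu", (2, 2), mesh_dim_names=("dp_replicate", "dp_shard_cp")
)

try:
    # A 2-D mesh with no mesh_dim argument cannot pick a single process group.
    mesh.get_group()
except RuntimeError as err:
    print(f"get_group() on a multi-axis mesh fails: {err}")

# Flatten both axes into one named dim, the same way build_device_mesh() does,
# so get_group() can return the single process group spanning all of them.
dp_cp_mesh = mesh._flatten(mesh_dim_name="dp_cp")
dist.all_reduce(torch.ones(1), group=dp_cp_mesh.get_group())

dist.destroy_process_group()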

File tree

3 files changed (+9, -8 lines)

torchtitan/distributed/parallel_dims.py

Lines changed: 4 additions & 0 deletions

@@ -120,6 +120,10 @@ def dp_shard_enabled(self):
     def cp_enabled(self):
         return self.cp > 1
 
+    @property
+    def dp_cp_enabled(self):
+        return self.dp_enabled or self.cp_enabled
+
     @property
     def tp_enabled(self):
         return self.tp > 1
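
As a quick illustration of what the new property composes, here is a stripped-down, hypothetical stand-in for ParallelDims. The definition of dp_enabled below (true when either dp_replicate or dp_shard is greater than 1) is an assumption; cp_enabled and dp_cp_enabled follow the diff.

# Hypothetical sketch, not the real ParallelDims from torchtitan/distributed/parallel_dims.py.
from dataclasses import dataclass

@dataclass
class ParallelDimsSketch:
    dp_replicate: int = 1
    dp_shard: int = 1
    cp: int = 1

    @property
    def dp_enabled(self) -> bool:
        # Assumption: data parallelism is active if either DP axis is > 1.
        return self.dp_replicate > 1 or self.dp_shard > 1

    @property
    def cp_enabled(self) -> bool:
        return self.cp > 1

    @property
    def dp_cp_enabled(self) -> bool:
        # Mirrors the property added in this commit: true whenever any
        # data-parallel or context-parallel axis is active, i.e. whenever a
        # flattened "dp_cp" group is meaningful.
        return self.dp_enabled or self.cp_enabled

# e.g. HSDP with 2-way replicate and 4-way shard, no context parallelism:
assert ParallelDimsSketch(dp_replicate=2, dp_shard=4).dp_cp_enabled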

torchtitan/experiments/llama4/infra/parallelize_llama.py

Lines changed: 4 additions & 2 deletions

@@ -118,13 +118,15 @@ def parallelize_llama(
     )
 
     # for MoE auxiliary-loss-free load balancing
-    if dp_mesh is not None:
+    if parallel_dims.dp_cp_enabled is not None:
         # NOTE: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
+        dp_cp_mesh = world_mesh["dp_cp"]
+
         def _sync_tokens_per_expert(module, *_):
             assert isinstance(module, MoE)
             torch.distributed.all_reduce(
-                module.tokens_per_expert, group=dp_mesh.get_group()
+                module.tokens_per_expert, group=dp_cp_mesh.get_group()
             )
 
         for transformer_block in model.layers.values():
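
The hunk ends just before the hook is attached to the MoE blocks. As a rough, hypothetical sketch of the surrounding pattern: the MoE stand-in's implementation and the choice of register_full_backward_hook below are assumptions, while the tokens_per_expert buffer and the all_reduce over dp_cp_mesh.get_group() come from the diff.

# Hypothetical sketch; the actual registration code lies outside this hunk.
import torch
import torch.distributed as dist
import torch.nn as nn

class MoE(nn.Module):
    # Stand-in for torchtitan's MoE layer: tracks how many tokens each expert received.
    def __init__(self, num_experts: int):
        super().__init__()
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts))

def make_sync_hook(dp_cp_mesh):
    # Closure mirroring _sync_tokens_per_expert from the diff: sum the per-expert
    # token counts across every rank in the flattened dp_cp group.
    def _sync_tokens_per_expert(module, *_):
        assert isinstance(module, MoE)
        dist.all_reduce(module.tokens_per_expert, group=dp_cp_mesh.get_group())
    return _sync_tokens_per_expert

# One plausible way to attach it per MoE block (assumed, not from the diff):
# moe_block.register_full_backward_hook(make_sync_hook(dp_cp_mesh))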

torchtitan/train.py

Lines changed: 1 addition & 6 deletions

@@ -442,12 +442,7 @@ def train_step(
         if not self.metrics_processor.should_log(self.step):
             return
 
-        if (
-            parallel_dims.dp_replicate_enabled
-            or parallel_dims.dp_shard_enabled
-            or parallel_dims.cp_enabled
-            or self.ft_manager.enabled
-        ):
+        if parallel_dims.dp_cp_enabled or self.ft_manager.enabled:
             loss = loss.detach()
             # Skip ft manager communication when using semi sync training
             use_ft_pg = (
