Commit aae7323
support torchft streaming diloco (#1302)
Summary:
- update to use the changed api to create diloco (we need to pass in model parts)
- add configuration options for streaming diloco

Test Plan:
```
$ NGPU=2 ./run_train.sh --fault_tolerance.enable --fault_tolerance.group_size=1 --fault_tolerance.semi_sync_method=diloco --fault_tolerance.sync_steps=2 --fault_tolerance.replica_id=0 --fault_tolerance.fragment_sync_delay=1 --fault_tolerance.fragment_update_alpha=0.0
[rank0]:[titan] 2025-06-16 09:39:08,893 - root - INFO - Model llama3 debugmodel size: 6,270,208 total parameters
[rank0]:[titan] 2025-06-16 09:39:08,894 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-06-16 09:39:08,952 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-06-16 09:39:09,375 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-06-16 09:39:09,376 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-06-16 09:39:09,376 - root - INFO - CUDA memory usage for model: 0.03GiB(0.04%)
[rank0]:[titan] 2025-06-16 09:39:09,377 - root - INFO - Trainer is initialized with local batch size 8, global batch size 16, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2).
[rank0]:[titan] 2025-06-16 09:39:09,377 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 09:39:10,325 - root - INFO - step: 1 loss: 8.1934 memory: 1.26GiB(1.59%) tps: 11,442 tflops: 0.82 mfu: 0.26%
[rank0]:[titan] 2025-06-16 09:39:10,325 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 09:39:10,431 - root - INFO - step: 2 loss: 8.1507 memory: 1.35GiB(1.71%) tps: 154,916 tflops: 11.14 mfu: 3.57%
[rank0]:[titan] 2025-06-16 09:39:10,524 - root - INFO - step: 3 loss: 8.0737 memory: 1.35GiB(1.71%) tps: 177,405 tflops: 12.76 mfu: 4.09%
[rank0]:[titan] 2025-06-16 09:39:10,623 - root - INFO - step: 4 loss: 7.8865 memory: 1.35GiB(1.71%) tps: 167,289 tflops: 12.03 mfu: 3.86%
[rank0]:[titan] 2025-06-16 09:39:10,714 - root - INFO - step: 5 loss: 7.7620 memory: 1.35GiB(1.71%) tps: 179,656 tflops: 12.92 mfu: 4.14%
[rank0]:[titan] 2025-06-16 09:39:10,808 - root - INFO - step: 6 loss: 7.5449 memory: 1.35GiB(1.71%) tps: 175,901 tflops: 12.65 mfu: 4.05%
[rank0]:[titan] 2025-06-16 09:39:10,911 - root - INFO - step: 7 loss: 7.3452 memory: 1.35GiB(1.71%) tps: 159,859 tflops: 11.49 mfu: 3.68%
[rank0]:[titan] 2025-06-16 09:39:11,005 - root - INFO - step: 8 loss: 7.2973 memory: 1.35GiB(1.71%) tps: 175,980 tflops: 12.65 mfu: 4.06%
[rank0]:[titan] 2025-06-16 09:39:11,096 - root - INFO - step: 9 loss: 7.1333 memory: 1.35GiB(1.71%) tps: 179,903 tflops: 12.94 mfu: 4.15%
[rank0]:[titan] 2025-06-16 09:39:11,186 - root - INFO - step: 10 loss: 7.0747 memory: 1.35GiB(1.71%) tps: 184,628 tflops: 13.28 mfu: 4.26%
[rank0]:[titan] 2025-06-16 09:39:11,186 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-06-16 09:39:13,186 - root - INFO - Training completed
[rank0]:[titan] 2025-06-16 09:39:13,489 - root - INFO - Process group destroyed.
```
1 parent 6b0fd66 commit aae7323
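For orientation, here is a minimal sketch of the updated calling convention, assuming only the names visible in the diffs below (`ft.maybe_semi_sync_training`, the `fault_tolerance` config fields). The trainer objects and the inner loop are illustrative placeholders, not torchtitan's actual `Trainer`:

```python
# Sketch only: the context manager now receives the list of model parts and
# reads sync_steps plus the new streaming DiLoCo knobs from
# job_config.fault_tolerance itself (there is no separate sync_every argument).
from torchtitan.components import ft


def run_with_semi_sync(job_config, ft_manager, model_parts, optimizers, data_iterator, train_step):
    # Effectively a no-op wrapper unless fault_tolerance.enable and
    # fault_tolerance.semi_sync_method are set; otherwise it wraps the loop
    # in torchft's DiLoCo (or LocalSGD) semi-synchronous training.
    with ft.maybe_semi_sync_training(
        job_config,
        ft_manager=ft_manager,
        model_parts=model_parts,   # e.g. one module per pipeline stage
        optimizer=optimizers,
    ):
        for batch in data_iterator:
            train_step(batch)      # inner steps; fragments sync every sync_steps
```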

3 files changed: +36 -10 lines


torchtitan/components/ft.py

Lines changed: 12 additions & 8 deletions
```diff
@@ -170,15 +170,15 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
 def maybe_semi_sync_training(
     config: JobConfig,
     ft_manager: FTManager,
-    model: torch.nn.Module,
+    model_parts: list[torch.nn.Module],
     optimizer: torch.optim.Optimizer,
-    sync_every: int,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
     """
-    semi_sync_method = config.fault_tolerance.semi_sync_method
-    torchft_enabled = config.fault_tolerance.enable
+    ft_config = config.fault_tolerance
+    semi_sync_method = ft_config.semi_sync_method
+    torchft_enabled = ft_config.enable
     if torchft_enabled and semi_sync_method is not None:
         from torchft import local_sgd

@@ -195,17 +195,21 @@ def maybe_semi_sync_training(

             return local_sgd.DiLoCo(
                 manager=ft_manager._manager,
-                model=model,
+                model_fragments=model_parts,
                 inner_optimizer=optimizer,
                 outer_optimizer=outer_optimizer,
-                sync_every=sync_every,
+                sync_every=ft_config.sync_steps,
+                should_quantize=ft_config.should_quantize,
+                fragment_sync_delay=ft_config.fragment_sync_delay,
+                fragment_update_alpha=ft_config.fragment_update_alpha,
             )
         elif semi_sync_method.lower() == "local_sgd":
+            assert len(model_parts) == 1
             return local_sgd.LocalSGD(
                 manager=ft_manager._manager,
-                model=model,
+                model=model_parts[0],
                 optimizer=optimizer,
-                sync_every=sync_every,
+                sync_every=ft_config.sync_steps,
             )
         else:
             raise ValueError(
```

torchtitan/config_manager.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -584,6 +584,29 @@ class FaultTolerance:
     is set.
     """

+    should_quantize: bool = False
+    """
+    Whether to quantize the gradients before allreduce.
+
+    This is only used when "semi_sync_method" is set.
+    """
+
+    fragment_sync_delay: int = 0
+    """
+    Controls the number of inner steps to wait before blocking on a
+    model fragment's synchronization. This is the "tao" parameter in
+    the Streaming DiLoCo paper.
+
+    This is only used when "semi_sync_method" is set.
+    """
+
+    fragment_update_alpha: float = 0.0
+    """
+    Determines how to mix the local and global optimized parameters
+
+    This is only used when "semi_sync_method" is set.
+    """
+

 @dataclass
 class Experimental:
```
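For concreteness, with the flags used in the Test Plan above, these new fields end up configuring torchft's streaming DiLoCo roughly as follows. This is a sketch based on the `ft.py` diff above; `ft_manager`, `model_parts`, `optimizer`, and `outer_optimizer` are placeholders for the trainer's real objects:

```python
# Illustrative values mirroring the test-plan CLI flags; the keyword names are
# exactly the ones passed in the ft.py diff above. ft_manager, model_parts,
# optimizer, and outer_optimizer come from the trainer (placeholders here).
from torchft import local_sgd

semi_sync_ctx = local_sgd.DiLoCo(
    manager=ft_manager._manager,
    model_fragments=model_parts,     # each model part becomes a DiLoCo fragment
    inner_optimizer=optimizer,
    outer_optimizer=outer_optimizer,
    sync_every=2,                    # fault_tolerance.sync_steps
    should_quantize=False,           # quantize gradients before allreduce (default)
    fragment_sync_delay=1,           # inner steps before blocking on a fragment sync
    fragment_update_alpha=0.0,       # how to mix local and global optimized parameters
)
```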

torchtitan/train.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -480,9 +480,8 @@ def train(self):
             ft.maybe_semi_sync_training(
                 job_config,
                 ft_manager=self.ft_manager,
-                model=self.model_parts[0],
+                model_parts=self.model_parts,
                 optimizer=self.optimizers,
-                sync_every=job_config.fault_tolerance.sync_steps,
             ),
         ):
             data_iterator = self.batch_generator(self.dataloader)
```
