
Commit 1ab4353

document default parameters for streaming diloco (#1308)
Summary: document why default parameters are set the way they are for streaming diloco

Test Plan:

```
$ NGPU=2 ./run_train.sh --fault_tolerance.enable --fault_tolerance.group_size=1 --fault_tolerance.semi_sync_method=diloco --fault_tolerance.sync_steps=2 --fault_tolerance.replica_id=0 --fault_tolerance.fragment_sync_delay=1 --fault_tolerance.fragment_update_alpha=0.0
[rank0]:[titan] 2025-06-16 09:39:08,893 - root - INFO - Model llama3 debugmodel size: 6,270,208 total parameters
[rank0]:[titan] 2025-06-16 09:39:08,894 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-06-16 09:39:08,952 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-06-16 09:39:09,375 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-06-16 09:39:09,376 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-06-16 09:39:09,376 - root - INFO - CUDA memory usage for model: 0.03GiB(0.04%)
[rank0]:[titan] 2025-06-16 09:39:09,377 - root - INFO - Trainer is initialized with local batch size 8, global batch size 16, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2).
[rank0]:[titan] 2025-06-16 09:39:09,377 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 09:39:10,325 - root - INFO - step: 1 loss: 8.1934 memory: 1.26GiB(1.59%) tps: 11,442 tflops: 0.82 mfu: 0.26%
[rank0]:[titan] 2025-06-16 09:39:10,325 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 09:39:10,431 - root - INFO - step: 2 loss: 8.1507 memory: 1.35GiB(1.71%) tps: 154,916 tflops: 11.14 mfu: 3.57%
[rank0]:[titan] 2025-06-16 09:39:10,524 - root - INFO - step: 3 loss: 8.0737 memory: 1.35GiB(1.71%) tps: 177,405 tflops: 12.76 mfu: 4.09%
[rank0]:[titan] 2025-06-16 09:39:10,623 - root - INFO - step: 4 loss: 7.8865 memory: 1.35GiB(1.71%) tps: 167,289 tflops: 12.03 mfu: 3.86%
[rank0]:[titan] 2025-06-16 09:39:10,714 - root - INFO - step: 5 loss: 7.7620 memory: 1.35GiB(1.71%) tps: 179,656 tflops: 12.92 mfu: 4.14%
[rank0]:[titan] 2025-06-16 09:39:10,808 - root - INFO - step: 6 loss: 7.5449 memory: 1.35GiB(1.71%) tps: 175,901 tflops: 12.65 mfu: 4.05%
[rank0]:[titan] 2025-06-16 09:39:10,911 - root - INFO - step: 7 loss: 7.3452 memory: 1.35GiB(1.71%) tps: 159,859 tflops: 11.49 mfu: 3.68%
[rank0]:[titan] 2025-06-16 09:39:11,005 - root - INFO - step: 8 loss: 7.2973 memory: 1.35GiB(1.71%) tps: 175,980 tflops: 12.65 mfu: 4.06%
[rank0]:[titan] 2025-06-16 09:39:11,096 - root - INFO - step: 9 loss: 7.1333 memory: 1.35GiB(1.71%) tps: 179,903 tflops: 12.94 mfu: 4.15%
[rank0]:[titan] 2025-06-16 09:39:11,186 - root - INFO - step: 10 loss: 7.0747 memory: 1.35GiB(1.71%) tps: 184,628 tflops: 13.28 mfu: 4.26%
[rank0]:[titan] 2025-06-16 09:39:11,186 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-06-16 09:39:13,186 - root - INFO - Training completed
[rank0]:[titan] 2025-06-16 09:39:13,489 - root - INFO - Process group destroyed.
```
1 parent aae7323 commit 1ab4353

File tree

1 file changed: +14 -0 lines changed


torchtitan/config_manager.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -588,6 +588,11 @@ class FaultTolerance:
     """
     Whether to quantize the gradients before allreduce.
 
+    Disabled by default since the quantization does utilize the GPU
+    and uses more collectives. Enabling this requires knowing about
+    the tradeoffs between GPU utilization and communication.
+
     This is only used when "semi_sync_method" is set.
     """
```
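To make the tradeoff above concrete, here is a minimal, purely illustrative sketch (not the collective torchft actually uses; it assumes an already-initialized NCCL process group): sending the gradient in a narrower dtype reduces the bytes on the wire, but the casts around the collective are extra GPU kernels, which is the GPU-utilization cost the docstring refers to.

```python
import torch
import torch.distributed as dist


def allreduce_grad_lowprec(grad: torch.Tensor) -> None:
    """Hypothetical helper: average a gradient across ranks in bfloat16.

    Fewer bytes cross the network, but the two casts below run as extra
    GPU kernels -- the utilization vs. communication tradeoff noted in
    the config docstring. Illustration only, not torchft's implementation.
    """
    buf = grad.to(torch.bfloat16)               # narrow the dtype before the collective
    dist.all_reduce(buf, op=dist.ReduceOp.AVG)  # average across all replicas
    grad.copy_(buf.to(grad.dtype))              # cast back into the original gradient
```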

```diff
@@ -597,13 +602,22 @@ class FaultTolerance:
     model fragment's synchronization. This is the "tau" parameter in
     the Streaming DiLoCo paper.
 
+    By default, each model fragment will be synced at the same step
+    at which the allreduce is issued. Enabling delay can improve
+    communication and computation overlap, but at the cost of compromising
+    model quality.
+
     This is only used when "semi_sync_method" is set.
     """
 
     fragment_update_alpha: float = 0.0
     """
     Determines how to mix the local and global optimized parameters.
 
+    By default, we just use the global parameters. This ensures all
+    DDP replicas have the same parameters after synchronizing on
+    the fragment. Tuning this can also affect the model quality.
+
     This is only used when "semi_sync_method" is set.
     """
```
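For fragment_update_alpha, the mixing described in the docstring can be sketched as a simple convex combination (an illustration of the general idea rather than the code torchft runs): with the default alpha = 0.0 every replica adopts the globally averaged fragment unchanged, so all DDP replicas hold identical parameters after the sync, while a larger alpha retains more of each replica's locally optimized parameters.

```python
import torch


def mix_fragment_params(
    local: torch.Tensor, global_avg: torch.Tensor, alpha: float = 0.0
) -> torch.Tensor:
    """Illustrative blend of local and globally averaged fragment parameters.

    alpha = 0.0 (the default) returns the global average unchanged, so every
    DDP replica ends up with identical parameters after the fragment sync.
    """
    return alpha * local + (1.0 - alpha) * global_avg
```

fragment_sync_delay controls when such an update lands: with the default of 0 the fragment is updated at the same step its allreduce is issued, while a positive delay lets subsequent training steps overlap with the in-flight communication, at some cost to model quality.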
