
Commit 4aa9fde

wwwjn and tianyu-l authored
Add debugging instructions for "Reproducibility between runs" (#1363)
As titled. To help anyone with the same issue as me understand why we need a "seed checkpoint". --------- Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
1 parent a04f6bd commit 4aa9fde

File tree

1 file changed: +58 −0 lines changed


docs/debugging.md

Lines changed: 58 additions & 0 deletions
@@ -59,3 +59,61 @@ This will print a structured configuration to `stdout`, allowing you to verify t
If you encounter jobs that time out, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs.
When a job times out, Flight Recorder automatically generates dump files on every rank containing valuable debugging data. You can find these dump files in the `job.dump_folder` directory.
To learn how to analyze and diagnose issues using these logs, follow our step-by-step [tutorial](https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html).

## Reproducibility between Runs

When debugging issues with multi-dimensional parallelism (combinations of FSDP, TP, PP, CP, EP), ensuring reproducible behavior is crucial for isolating and fixing problems. `torchtitan` provides several mechanisms to achieve deterministic training runs. For more information on ensuring reproducibility and managing randomness in PyTorch, refer to the official notes on randomness: [PyTorch Randomness Documentation](https://docs.pytorch.org/docs/stable/notes/randomness.html).

### Seed Configuration

Set consistent random seeds across all parallelism dimensions:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42
```

**Seed behavior with parallelism:**
- **Data Parallel (DP/FSDP), Tensor Parallel (TP), Context Parallel (CP):** All ranks use the same seed.
  - Note: For FSDP and TP, DTensor performs special RNG management so that a Replicate tensor gets the same init across ranks, while a Shard tensor gets a different, "random"-like init across ranks.
- **Pipeline Parallel (PP):** Each PP stage gets a different seed to ensure different initialization across layers on different PP ranks.

### Deterministic Mode

Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic
```

**What it does:**
- Forces all CUDA operations to use deterministic algorithms
- Disables CuDNN benchmarking and enables deterministic mode
- Sets deterministic workspace configuration for CuBLAS operations
- **Note:** This will significantly reduce training performance but ensures exact reproducibility
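
When chasing a reproducibility issue, the seed and deterministic flags are typically combined in a single run. A minimal sketch (the seed value is illustrative):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42 --training.deterministic
```

Running the same command twice with an unchanged config should produce matching loss values step for step, which is a quick sanity check before comparing different parallelism configs.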

### Seed-Checkpoint-based Reproducibility

For multiple experimental runs with different parallelism configs, we need to use a "seed" checkpoint to ensure model initializations are the same across runs. This is because in `torchtitan/train.py`, the model parameters are sharded first and then have their weights initialized on each rank separately, which is not equivalent to initializing the model on one rank and then sharding it. Loading the same seed checkpoint gives different runs identical model weights -- DCP resharding will make sure the loaded weights are sharded correctly according to the parallelism configs.

```bash
NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
```

**Note**: Using a seed checkpoint only ensures that a model starts from the same initial weights when configs change; the training process may still differ even after setting the seed and enabling `deterministic` mode, e.g. due to tensor shape changes, data precision changes, usage of randomness in model code, etc.
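
As an illustration, once the seed checkpoint exists, a subsequent run with a different parallelism layout can load it by keeping checkpointing enabled. This sketch assumes the run reuses the same `job.dump_folder` (so the step-0 seed checkpoint is discovered and loaded), and the degrees shown are only examples:

```bash
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --parallelism.data_parallel_shard_degree 4 --parallelism.tensor_parallel_degree 2
```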

### Example: Reproducing loss curves with different parallelism configs

A common scenario is introducing a new parallelism strategy to the model and needing to verify that the loss curve remains numerically equivalent to the previous parallelism config, thereby confirming the correctness of your implementation. To achieve consistent behavior across multiple runs with varying parallelism configurations, it's crucial that the dataloader behaves consistently, so we fix the DP degree (`dp_replicate * dp_shard`) across runs.

Here's a typical comparison setup (maintaining an overall DP degree of 4):
- Run 1: dp_shard = 4
- Run 2: dp_replicate = 2, dp_shard = 2, TP degree = 2
- Run 3: dp_replicate = 2, dp_shard = 2, CP degree = 2, PP degree = 2

To reproduce loss curves across the above runs, you'll need to create a seed checkpoint and then load that same seed checkpoint in all runs to ensure consistent model initialization on each rank. You might also need to enable `deterministic` mode to ensure consistent training behavior.
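
A sketch of what the three runs could look like, assuming the world size equals the product of the parallel degrees, all runs reuse the same checkpoint folder containing the seed checkpoint, and the debug model has enough layers for the chosen PP degree:

```bash
# Run 1: FSDP only, overall DP degree 4
NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --checkpoint.enable_checkpoint --training.seed 42 --training.deterministic \
  --parallelism.data_parallel_shard_degree 4

# Run 2: DP replicate 2 x shard 2, plus TP 2 (overall DP degree still 4)
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --checkpoint.enable_checkpoint --training.seed 42 --training.deterministic \
  --parallelism.data_parallel_replicate_degree 2 --parallelism.data_parallel_shard_degree 2 \
  --parallelism.tensor_parallel_degree 2

# Run 3: DP replicate 2 x shard 2, plus CP 2 and PP 2 (overall DP degree still 4)
NGPU=16 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --checkpoint.enable_checkpoint --training.seed 42 --training.deterministic \
  --parallelism.data_parallel_replicate_degree 2 --parallelism.data_parallel_shard_degree 2 \
  --parallelism.context_parallel_degree 2 --parallelism.pipeline_parallel_degree 2
```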

We also provide an example of verifying numerical consistency across parallelism configs on Llama 3 in https://github.com/pytorch/torchtitan/blob/main/docs/converging.md.
