Skip to content

Commit db52d57

Browse files
ankitageorgeankitageorgeankitageorgetianyu-l
authored
Add support for saving HF format tensors with DCP (#1351)
If checkpoint.last_save_in_safetensors_format is set, then save the checkpoint with DCP HF components that will save the checkpoint in .safetensors files instead of regular DCP format on final save. On load, we can decide which type of load to do based on checkpoint type. Successful save: ``` (titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0,1,2 + LOG_RANK=0,1,2 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] ***************************************** W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] ***************************************** [rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training [rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training [rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training [rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:NCCL version 2.27.5+cuda12.9 [rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4 [rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model [rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model [rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-1921 [rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model [rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1. [rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1. [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,046 tflops: 60.58 mfu: 19.42% [rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1 [rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 971 tflops: 56.23 mfu: 18.02% [rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1 [rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,038 tflops: 60.13 mfu: 19.27% [rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1 [rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2 [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291 [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2 [rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291 [rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2 [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291 [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758 [rank0]:Done writing data. Took %.2f secs. 66.62590146064758 [rank0]:Done consolidating. Took %.2f secs. 66.62735033035278 [rank0]:time taken for all reduce: 141.72666668891907 [rank1]:time taken for all reduce: 141.73284125328064 [rank2]:time taken for all reduce: 141.72900009155273 [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds. [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread. [rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds. [rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds. [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread. [rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed. [rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed [rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread. [rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed. [rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed. ``` Successful load: ``` (titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] ***************************************** W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] ***************************************** [rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training [rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056 [rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model [rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200). [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3. [rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.) [rank0]: tensor = torch.frombuffer( [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step: 1 loss: 12.0247 grad_norm: 42.7524 memory: 42.12GiB(53.23%) tps: 236 tflops: 13.67 mfu: 4.38% [rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 ``` --------- Co-authored-by: ankitageorge <ankitageorge@devvm6863.rva0.facebook.com> Co-authored-by: ankitageorge <ankitageorge@devvm2888.eag0.facebook.com> Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
1 parent 6204cdf commit db52d57

File tree

6 files changed

+202
-22
lines changed

6 files changed

+202
-22
lines changed

.ci/docker/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ wandb
77
fsspec
88
tyro
99
tokenizers >= 0.15.0
10+
safetensors

docs/checkpoint.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,8 @@ e.g.
8585
```bash
8686
NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
8787
```
88+
89+
90+
## How to load / save a checkpoint in HF safetensors format
91+
For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_weights_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format).
92+
For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_weights_only` to load the checkpoint in HF format.

tests/integration_tests.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,22 @@ def build_test_list():
118118
"Checkpoint Integration Test - Save Load Full Checkpoint",
119119
"full_checkpoint",
120120
),
121+
OverrideDefinitions(
122+
[
123+
[
124+
"--checkpoint.enable_checkpoint",
125+
"--checkpoint.folder hf_checkpoint",
126+
"--checkpoint.last_save_in_safetensors_format",
127+
"--checkpoint.last_save_model_weights_only",
128+
],
129+
[
130+
"--checkpoint.enable_checkpoint",
131+
"--checkpoint.initial_load_path artifacts-to-be-uploaded/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/",
132+
],
133+
],
134+
"Checkpoint Integration Test - save load full checkpoint in HF safetensors format",
135+
"full_checkpoint_hf_safetensors",
136+
),
121137
OverrideDefinitions(
122138
[
123139
[

tests/unit_tests/test_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def tearDown(self):
144144
shutil.rmtree(self.base_temp_dir)
145145
time.sleep(0.1)
146146

147-
def fake_save(self, state_dict: dict, checkpoint_id: str):
147+
def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
148148
os.makedirs(checkpoint_id, exist_ok=True)
149149
sd_to_save = {}
150150
for key, val in state_dict.items():
@@ -584,7 +584,7 @@ def __init__(self):
584584
@mock.patch("torchtitan.components.checkpoint.dcp.load")
585585
@mock.patch("torchtitan.components.checkpoint.dcp.save")
586586
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
587-
def fake_save(state_dict: dict, checkpoint_id: str):
587+
def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
588588
self.assertIn("bias", state_dict)
589589
self.assertIn("weight", state_dict)
590590
# No model prefix

torchtitan/components/checkpoint.py

Lines changed: 167 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
import shutil
1313
import threading
1414
import time
15+
from concurrent.futures import Future
1516
from typing import Any
1617

1718
import torch
1819
import torch.distributed as dist
1920
import torch.distributed.checkpoint as dcp
2021
import torch.nn as nn
22+
from torch.distributed.checkpoint import (
23+
HuggingFaceStorageReader,
24+
HuggingFaceStorageWriter,
25+
)
2126
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
2227
from torch.distributed.checkpoint.state_dict import (
2328
get_model_state_dict,
@@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum):
4954
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"
5055

5156

57+
class CheckpointType(str, enum.Enum):
58+
DCP = "DCP"
59+
SAFETENSORS = "safetensors"
60+
61+
5262
# For now, we will manually pop the freqs_cis buffer, as we made this permanent
5363
# temporarily and we don't want to include it in the exported state_dict.
5464
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
@@ -92,12 +102,6 @@ class SaveDone:
92102
pass
93103

94104

95-
@torch.no_grad()
96-
def save_with_gc(state, checkpoint_id):
97-
dcp.save(state, checkpoint_id=checkpoint_id)
98-
GarbageCollection.collect("GC collection invoked by checkpointer.")
99-
100-
101105
def purge_thread(purge_queue: queue.Queue):
102106
"""Thread to purge the old checkpoints.
103107
@@ -190,6 +194,9 @@ def __init__(
190194
) -> None:
191195
ckpt_config = job_config.checkpoint
192196
self.enable_checkpoint = ckpt_config.enable_checkpoint
197+
self.last_save_in_safetensors_format = (
198+
ckpt_config.last_save_in_safetensors_format
199+
)
193200
self.ft_manager = (
194201
ft_manager.manager if ft_manager and ft_manager.enabled else None
195202
)
@@ -314,6 +321,98 @@ def close(self):
314321
if self.stager is not None:
315322
self.stager.close()
316323

324+
@torch.no_grad()
325+
def dcp_save(
326+
self,
327+
state_dict: dict[str, Any],
328+
checkpoint_id: str,
329+
async_mode: AsyncMode,
330+
enable_garbage_collection: bool = False,
331+
save_in_safetensors_format: bool = False,
332+
) -> Future | None:
333+
"""Save the checkpoint with dcp.
334+
Args:
335+
state_dict (dict): The state dict to save.
336+
checkpoint_id (str): The checkpoint id to save.
337+
async_mode (AsyncMode): Whether the checkpoint is async.
338+
enable_garbage_collection (bool): Whether to enable garbage collection after save.
339+
save_in_safetensors_format (bool): Whether to save in safetensors format.
340+
341+
Returns:
342+
Future: The future object if the checkpoint is async, otherwise None.
343+
"""
344+
345+
ret: Future | None = None
346+
347+
storage_writer: HuggingFaceStorageWriter | None = None
348+
checkpoint_save_id: str | None = None
349+
if save_in_safetensors_format:
350+
fqn_to_index_mapping = {}
351+
num_fqns_per_file = 30
352+
# the use of 30 is just a heuristic for now.
353+
# Once these fqns map to HF ones, we can use the fqn mapping
354+
# from the model.safetensors.index.json file
355+
for i, key in enumerate(state_dict.keys()):
356+
group_num = (i // num_fqns_per_file) + 1
357+
fqn_to_index_mapping[key] = group_num
358+
359+
storage_writer = HuggingFaceStorageWriter(
360+
path=checkpoint_id,
361+
save_distributed=True,
362+
fqn_to_index_mapping=fqn_to_index_mapping,
363+
enable_consolidation=True,
364+
thread_count_consolidation=5,
365+
)
366+
else:
367+
checkpoint_save_id = checkpoint_id
368+
369+
if async_mode == AsyncMode.ASYNC:
370+
ret = dcp.async_save(
371+
state_dict,
372+
storage_writer=storage_writer,
373+
checkpoint_id=checkpoint_save_id,
374+
process_group=self.pg,
375+
)
376+
elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
377+
ret = dcp.async_save(
378+
state_dict,
379+
storage_writer=storage_writer,
380+
checkpoint_id=checkpoint_save_id,
381+
process_group=self.pg,
382+
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
383+
async_stager=self.stager,
384+
)
385+
else:
386+
ret = dcp.save(
387+
state_dict,
388+
storage_writer=storage_writer,
389+
checkpoint_id=checkpoint_save_id,
390+
)
391+
392+
if enable_garbage_collection:
393+
GarbageCollection.collect("GC collection invoked by checkpointer.")
394+
395+
return ret
396+
397+
def dcp_load(
398+
self,
399+
state_dict: dict[str, Any],
400+
checkpoint_id: str,
401+
checkpoint_type: CheckpointType,
402+
) -> None:
403+
"""Load the checkpoint with dcp.
404+
Args:
405+
state_dict (dict): The state dict to load.
406+
checkpoint_id (str): The checkpoint id to load.
407+
hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
408+
"""
409+
410+
if checkpoint_type == CheckpointType.SAFETENSORS:
411+
storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
412+
dcp.load(state_dict, storage_reader=storage_reader)
413+
else:
414+
dcp.load(state_dict, checkpoint_id=checkpoint_id)
415+
317416
@torch.no_grad()
318417
def save(self, curr_step: int, last_step: bool = False) -> None:
319418
"""Save the checkpoint for the current step.
@@ -354,23 +453,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
354453
GarbageCollection.collect("GC collection invoked by checkpointer.")
355454
if self.stager is None:
356455
self.stager = DefaultStager(StagingOptions(True, True, True, True))
357-
result = dcp.async_save(
456+
result = self.dcp_save(
358457
states,
359458
checkpoint_id=checkpoint_id,
360-
process_group=self.pg,
361-
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
362-
async_stager=self.stager,
459+
async_mode=self.async_mode,
363460
)
364461
self.save_future = result.upload_completion
365462
self.staging_future = result.staging_completion
366463
elif self.async_mode == AsyncMode.ASYNC:
367464
GarbageCollection.collect("GC collection invoked by checkpointer.")
368-
self.save_future = dcp.async_save(
369-
states, checkpoint_id=checkpoint_id, process_group=self.pg
465+
self.save_future = self.dcp_save(
466+
states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
370467
)
371468
GarbageCollection.collect("GC collection invoked by checkpointer.")
372469
else:
373-
save_with_gc(states, checkpoint_id=checkpoint_id)
470+
self.dcp_save(
471+
states,
472+
checkpoint_id=checkpoint_id,
473+
async_mode=AsyncMode.DISABLED,
474+
enable_garbage_collection=True,
475+
)
374476
self._purge_stale_checkpoints()
375477

376478
logger.info(
@@ -432,10 +534,19 @@ def load(self, step: int = -1) -> bool:
432534
f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
433535
)
434536

537+
checkpoint_type = self._find_checkpoint_type(checkpoint_id)
538+
if checkpoint_type == CheckpointType.SAFETENSORS:
539+
assert (
540+
model_only
541+
), "Only model weights can be loaded when loading from safetensors checkpoint."
435542
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
436543
begin = time.monotonic()
437544
states = self._states_to_load(model_only)
438-
dcp.load(states, checkpoint_id=checkpoint_id)
545+
self.dcp_load(
546+
states,
547+
checkpoint_id=checkpoint_id,
548+
checkpoint_type=checkpoint_type,
549+
)
439550
GarbageCollection.collect("GC collection for checkpoint loading.")
440551
logger.info(
441552
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -470,13 +581,33 @@ def _find_load_step(self, folder: str = "") -> int:
470581

471582
for filename in os.listdir(folder):
472583
match = re.search(pattern, filename)
473-
metadata_probe = os.path.join(folder, filename, ".metadata")
474-
if match and os.path.isfile(metadata_probe):
584+
dcp_metadata_probe = os.path.join(folder, filename, ".metadata")
585+
safetensors_metadata_probe = os.path.join(
586+
folder, filename, "model.safetensors.index.json"
587+
)
588+
if match and os.path.isfile(dcp_metadata_probe):
589+
step_counts.append(int(match.group(1)))
590+
elif match and os.path.isfile(safetensors_metadata_probe):
475591
step_counts.append(int(match.group(1)))
476592
if not step_counts:
477593
return -1
478594
return max(step_counts)
479595

596+
def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:
597+
"""Find the checkpoint type for the given id.
598+
599+
Args:
600+
checkpoint_id (str): The folder to find the checkpoint type for.
601+
602+
Returns:
603+
CheckpointType: The checkpoint type for the given folder.
604+
"""
605+
606+
for filename in os.listdir(checkpoint_id):
607+
if filename == "model.safetensors.index.json":
608+
return CheckpointType.SAFETENSORS
609+
return CheckpointType.DCP
610+
480611
def _ft_folder(self) -> str:
481612
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
482613

@@ -488,8 +619,8 @@ def _ft_save(self, step: int) -> None:
488619
begin = time.monotonic()
489620
self._async_wait()
490621
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
491-
self.save_future = dcp.async_save(
492-
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
622+
self.save_future = self.dcp_save(
623+
self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
493624
)
494625
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
495626

@@ -501,7 +632,12 @@ def _ft_load(self) -> None:
501632
begin = time.monotonic()
502633
logger.info(f"Loading the FT checkpoint at step {step}.")
503634
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
504-
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
635+
self.dcp_load(
636+
self.ft_states,
637+
checkpoint_id=checkpoint_id,
638+
# FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
639+
checkpoint_type=CheckpointType.DCP,
640+
)
505641
GarbageCollection.collect("GC collection for checkpoint loading.")
506642
logger.info(
507643
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -570,7 +706,18 @@ def _save_last_step(self, curr_step: int) -> None:
570706
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
571707
states = self._flattened_model_states_sd()
572708

573-
save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
709+
if self.last_save_in_safetensors_format:
710+
assert (
711+
self.last_save_model_weights_only
712+
), "Only model weights can be saved when saving in safetensors format."
713+
714+
self.dcp_save(
715+
states,
716+
checkpoint_id=self._create_checkpoint_id(curr_step),
717+
async_mode=AsyncMode.DISABLED,
718+
enable_garbage_collection=True,
719+
save_in_safetensors_format=self.last_save_in_safetensors_format,
720+
)
574721

575722
def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
576723
if not self.enable_checkpoint:

torchtitan/config_manager.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,17 @@ class Checkpoint:
475475
for many steps or checkpointing too frequently. The default value is False.
476476
"""
477477

478+
last_save_in_safetensors_format: bool = False
479+
"""
480+
Enable the use of safetensors format for checkpointing. This will save the final checkpoints
481+
in safetensors format instead of the default DCP format. There will be a performance
482+
cost in using this as we need to consolidate the sharded tensors to full tensors as
483+
a separate step. last_save_model_weights_only must be true because safetensors doesn't
484+
support saving non tensors. On load, this argument isn't needed as we will detect
485+
whether the loaded checkpoint is in safetensors format or not.
486+
The default value is False.
487+
"""
488+
478489

479490
@dataclass
480491
class ActivationCheckpoint:

0 commit comments

Comments
 (0)