Skip to content

Add support for saving HF format tensors with DCP #1351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jul 14, 2025
Merged

Add support for saving HF format tensors with DCP #1351

merged 30 commits into from
Jul 14, 2025

Conversation

ankitageorge
Copy link
Contributor

@ankitageorge ankitageorge commented Jun 27, 2025

If checkpoint.enable_save_safetensors_format is set, then save the checkpoint with DCP HF components that will save the checkpoint in .safetensors files instead of regular DCP format on final save. On load, we can decide which type of load to do based on checkpoint type.

Successful save:

(titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0,1,2
+ LOG_RANK=0,1,2
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] 
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training
[rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training
[rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training
[rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.9
[rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model
[rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-1921
[rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1.
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1.
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 1,046  tflops: 60.58  mfu: 19.42%
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 971  tflops: 56.23  mfu: 18.02%
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1
[rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 1,038  tflops: 60.13  mfu: 19.27%
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758
[rank0]:Done writing data. Took %.2f secs. 66.62590146064758
[rank0]:Done consolidating. Took %.2f secs. 66.62735033035278
[rank0]:time taken for all reduce:  141.72666668891907
[rank1]:time taken for all reduce:  141.73284125328064
[rank2]:time taken for all reduce:  141.72900009155273
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread.
[rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds.
[rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread.
[rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed
[rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread.
[rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed.

Successful load:

(titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] 
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training
[rank0]:NCCL version 2.27.5+cuda12.9
[rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056
[rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3.
[rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.)
[rank0]:  tensor = torch.frombuffer(
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step:  1  loss: 12.0247  grad_norm: 42.7524  memory: 42.12GiB(53.23%)  tps: 236  tflops: 13.67  mfu: 4.38%
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 27, 2025
@ankitageorge ankitageorge changed the title Dcp hf Add support for saving HF format tensors with DCP Jun 27, 2025
@fegin
Copy link
Contributor

fegin commented Jun 27, 2025

@Saiteja64 This will conflict with your PR.

Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall the logic LGTM, please address comments and ensure that this PR doesn't conflict with the PR from @Saiteja64. Please also add a test result -- save a hf checkpoint and load one back and check the accuracy.

Comment on lines 116 to 130
if hf_safetensors_format:
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
if is_async:
return dcp.async_save(
state_dict, storage_writer=storage_writer, process_group=pg
)
else:
return dcp.save(state_dict, storage_writer=storage_writer)
else:
if is_async:
return dcp.async_save(
state_dict, checkpoint_id=checkpoint_id, process_group=pg
)
else:
return dcp.save(state_dict, checkpoint_id=checkpoint_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should simplify the function as follow

Suggested change
if hf_safetensors_format:
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
if is_async:
return dcp.async_save(
state_dict, storage_writer=storage_writer, process_group=pg
)
else:
return dcp.save(state_dict, storage_writer=storage_writer)
else:
if is_async:
return dcp.async_save(
state_dict, checkpoint_id=checkpoint_id, process_group=pg
)
else:
return dcp.save(state_dict, checkpoint_id=checkpoint_id)
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) if hf_safetensors_format else None
checkpoint_id = checkpoint_id if not hf_safetensors_format else None
if is_async:
return dcp.async_save(
state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
)
else:
return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)

enable_hf_safetensors_format: bool = False
"""
Enable the use of safetensors format for checkpointing. This will save checkpoints
in safetensors format instead of the default DCP format. The default value is False.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also mention the possible performance penalty? It's not cost free, right?

@ankitageorge ankitageorge marked this pull request as ready for review July 11, 2025 12:41
Copy link
Contributor

@wwwjn wwwjn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM! Thanks for working on this! So from the logging, save a llama3 8B model checkpoints as HF format takes ~200s, and load a HF checkpoint needs ~30s, is this correct?

Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, LGTM, please fix the remaining comments.

Comment on lines 525 to 526
if checkpoint_type == CheckpointType.SAFETENSORS:
model_only = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should assert if model_only is not True, rather than silently change model_only value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how should I change the logic to allow for this? right now if os.path.exists(self.folder), then there is no way for model_only to be True other than when step == 0. self.initial_load_model_weights_only isn't used in this code path either. It's also already silently being changed in line 516 which is why I did it this way, but happy to change in the way you think is best

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is depending on your answer to my question for loading folder structure.
If I'm right in that question -- to load a HF safetensors format, it requires the user to either use initial_load_path or put it in a step-x folder. I think it makes sense to assert model_only == True here,

unless you think it makes sense to put the HF checkpoint in some step-10 / 20 / 50 folder (why? I can't think of such use cases). If that's the case we can modify the model_only logic on line 516 to reflect that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok if there is no use case, will follow your advice, I wasn't sure

self.dcp_load(
self.ft_states,
checkpoint_id=checkpoint_id,
checkpoint_type=CheckpointType.DCP, # FT checkpoints are always DCP
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add one more line, "because FT checkpoint currently only save/load dataloader.".

@ankitageorge
Copy link
Contributor Author

Overall LGTM! Thanks for working on this! So from the logging, save a llama3 8B model checkpoints as HF format takes ~200s, and load a HF checkpoint needs ~30s, is this correct?

save is actually faster than that. This run was ~140 seconds, but it was before I added the num_threads argument, so it should be faster now.

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

I left some suggestions, and a question on the next step of integrating the torchtitan <> HF model conversion.

Enable the use of safetensors format for checkpointing. This will save the final checkpoints
in safetensors format instead of the default DCP format. There will be a performance
cost in using this as we need to consolidate the sharded tensors to full tensors as
a separate step. Last_save_model_weights must be true because safetensors doesn't
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
a separate step. Last_save_model_weights must be true because safetensors doesn't
a separate step. last_save_model_weights_only must be true because safetensors doesn't

@@ -467,6 +467,17 @@ class Checkpoint:
for many steps or checkpointing too frequently. The default value is False.
"""

enable_save_safetensors_format: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit in naming

Suggested change
enable_save_safetensors_format: bool = False
last_save_in_safetensors_format: bool = False

path=checkpoint_id,
save_distributed=True,
fqn_to_index_mapping=fqn_to_index_mapping,
enable_consolidation=is_last_step,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems more readable if

Suggested change
enable_consolidation=is_last_step,
enable_consolidation=True,

)
if match and os.path.isfile(dcp_metadata_probe):
step_counts.append(int(match.group(1)))
elif match and os.path.isfile(safetensors_metadata_probe):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh for safetensors do we also require step-10, step-20, etc. type of checkpoint structure? If users want to load an existing checkpoint from HF, is the workflow

  1. download safetensors into local folder
  2. either put it in step-0 subfolder, or use the initial_load_path config?

Copy link
Contributor Author

@ankitageorge ankitageorge Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from your other comment, it seems like you think the HF checkpoint in step 10, 20 etc isn't a valid use case, so I will just follow your suggestion

Comment on lines 525 to 526
if checkpoint_type == CheckpointType.SAFETENSORS:
model_only = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is depending on your answer to my question for loading folder structure.
If I'm right in that question -- to load a HF safetensors format, it requires the user to either use initial_load_path or put it in a step-x folder. I think it makes sense to assert model_only == True here,

unless you think it makes sense to put the HF checkpoint in some step-10 / 20 / 50 folder (why? I can't think of such use cases). If that's the case we can modify the model_only logic on line 516 to reflect that.

checkpoint_id=self._create_checkpoint_id(curr_step),
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
is_last_step=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we don't pass is_last_step in, since it's only used when HF format matters.
I suggest we pass in save_in_safetensors_format = self.last_save_in_safetensors_format (after config name change).

In the beginning of this function, we need to assert last_save_model_weights_only == True if last_save_in_safetensors_format

Comment on lines 121 to 131
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_save_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint_hf_safetensors",
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to test both save in HF checkpoint and load from HF checkpoint in this test. Since we only support load weight-only initial checkpoint in HF format, I suggest the following for testing purposes only.

In the first run we save a step-10 checkpoint and pretend it to be the initial checkpoint to load; and then in the second run we use initial_load_path to locate the one we saved. Something like the following, but you may have to tweak the paths a bit.

Suggested change
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_save_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint_hf_safetensors",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.folder hf_checkpoint",
"--checkpoint.enable_save_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.initial_load_path outputs/hf_checkpoint/step-10",
],
],
"Checkpoint Integration Test - save load full checkpoint in HF safetensors format",
"full_checkpoint_hf_safetensors",
),


if checkpoint_type == CheckpointType.SAFETENSORS:
storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
dcp.load(state_dict, storage_reader=storage_reader)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a question on how to combine this PR with torchtitan model <-> HF model definition conversion mappings, so that we can load / save with HF models and train with torchtitan.

Let's say we have to mappings, torchtitan_to_hf, hf_to_torchtitan, both taking a state dict and convert it to another, e.g. very similar to https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama4/_convert_weights.py

I could imagine how we do this for save -- when HF format is used, we just don't call the current ModelWrapper.state_dict, we call conversion map torchtitan_to_hf on top of it.

How are we supposed to do load in HF definition? ModelWrapper.load_state_dict is only in fault-tolerant path, but seems not used by dcp.load, so how can we do things like fuse / unfuse of tensors after load? I think we need to manually add a function in dcp_load to do the reverse of torchtitan_to_hf inplace, similar to what ModelWrapper.load_state_dict does in set_model_state_dict. This also means hf_to_torchtitan itself is not useful.

cc @fegin @wwwjn @wesleytruong

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, had one more comment.

Also, would you please add a small section "How to load / save checkpoint in HF safetensor format" in https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md
For save, users need to set --checkpoint.last_save_in_safetensors_format and --checkpoint.last_save_model_weights_only and it only saves the last checkpoint in HF format (intermediate ones are in DCP format).
For load, users need to either put it in step-0 folder if using --checkpoint.folder, or specify --checkpoint.initial_load_path.

Comment on lines 364 to 366
checkpoint_save_id = (
checkpoint_id if not self.last_save_in_safetensors_format else None
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate what this field is for? Can we use save_in_safetensors_format instead of self.last_save_in_safetensors_format? If so let's put declare it together with storage_writer, and set it to checkpoint_id in the else branch.

Currently, if self.last_save_in_safetensors_format==True then checkpoint_save_id will always be None, regardless if it's last step or not.

This may not be covered by test, because with 10 steps we always same in one format but not both.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the non-HF case we need to pass in a checkpoint_id because we don't pass in a storage writer and just use the default storage writer and instantiate it with checkpoint_id. With HF case we pass it in as an arg to HFStorgaeWriter and then update it in that class, so we need to pass in an empty checkpoint_id to save

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks a lot for adding this feature!
Please address last nit comment on the .md tutorial.

ankitageorge and others added 2 commits July 14, 2025 16:51
Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
@ankitageorge ankitageorge merged commit db52d57 into main Jul 14, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants