Add support for saving HF format tensors with DCP (#1351)
If `checkpoint.last_save_in_safetensors_format` is set, the final save goes through the DCP HF components, which write the checkpoint as `.safetensors` files instead of the regular DCP format; intermediate saves still use DCP. On load, we decide which type of load to perform based on the checkpoint type.
Successful save:
```
(titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0,1,2
+ LOG_RANK=0,1,2
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774]
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training
[rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training
[rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training
[rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.9
[rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model
[rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-1921
[rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1.
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1.
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,046 tflops: 60.58 mfu: 19.42%
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 971 tflops: 56.23 mfu: 18.02%
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1
[rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,038 tflops: 60.13 mfu: 19.27%
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758
[rank0]:Done writing data. Took %.2f secs. 66.62590146064758
[rank0]:Done consolidating. Took %.2f secs. 66.62735033035278
[rank0]:time taken for all reduce: 141.72666668891907
[rank1]:time taken for all reduce: 141.73284125328064
[rank2]:time taken for all reduce: 141.72900009155273
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread.
[rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds.
[rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread.
[rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed
[rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread.
[rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed.
```
Successful load:
```
(titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774]
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training
[rank0]:NCCL version 2.27.5+cuda12.9
[rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056
[rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3.
[rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.)
[rank0]: tensor = torch.frombuffer(
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step: 1 loss: 12.0247 grad_norm: 42.7524 memory: 42.12GiB(53.23%) tps: 236 tflops: 13.67 mfu: 4.38%
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
```
---------
Co-authored-by: ankitageorge <ankitageorge@devvm6863.rva0.facebook.com>
Co-authored-by: ankitageorge <ankitageorge@devvm2888.eag0.facebook.com>
Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
## How to load / save a checkpoint in HF safetensors format
For save, users need to set both `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_weights_only` to save the last checkpoint in HF format (intermediate checkpoints are always saved in DCP format).
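For example, a minimal sketch of a save run, assuming checkpointing is already enabled in the TOML config and that the boolean options can be passed as bare command-line overrides to `run_train.sh` (exact override spelling may differ between versions):

```
# Illustrative invocation: write the final checkpoint as HF .safetensors
# (intermediate checkpoints remain in regular DCP format).
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
  --checkpoint.last_save_model_weights_only \
  --checkpoint.last_save_in_safetensors_format
```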
For load, users need to either place the checkpoint in the `step-0` folder when using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_weights_only` to load the checkpoint in HF format.
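For example, a minimal sketch of a load run, assuming the HF checkpoint lives outside `--checkpoint.folder` (the path below is a placeholder):

```
# Illustrative invocation: initialize model weights from an HF safetensors
# checkpoint in a separate folder, then train as usual.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
  --checkpoint.initial_load_path="/path/to/hf_safetensors_checkpoint" \
  --checkpoint.initial_load_model_weights_only
```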