Add support for saving HF format tensors with DCP #1351

Merged (30 commits) on Jul 14, 2025
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -7,3 +7,4 @@ wandb
fsspec
tyro
tokenizers >= 0.15.0
safetensors
5 changes: 5 additions & 0 deletions docs/checkpoint.md
@@ -85,3 +85,8 @@ e.g.
```bash
NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```


## How to load / save a checkpoint in HF safetensors format
To save the last checkpoint in HF safetensors format, set both `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_weights_only` (intermediate checkpoints are always saved in DCP format).
To load a checkpoint in HF format, either place it in the `step-0` folder when using `--checkpoint.folder`, or point `--checkpoint.initial_load_path` at the folder containing it. In either case, also set `--checkpoint.initial_load_model_weights_only`.
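
As a minimal end-to-end sketch (the config path, GPU count, and checkpoint location are placeholders, not values mandated by torchtitan):

```bash
# Save: train as usual; the final checkpoint is written as HF safetensors.
NGPU=8 CONFIG=<path_to_model_config> ./run_train.sh \
    --checkpoint.enable_checkpoint \
    --checkpoint.last_save_model_weights_only \
    --checkpoint.last_save_in_safetensors_format

# Load: start a new run from the safetensors checkpoint saved above (model weights only).
NGPU=8 CONFIG=<path_to_model_config> ./run_train.sh \
    --checkpoint.enable_checkpoint \
    --checkpoint.initial_load_path <path_to_saved_checkpoint_folder>/ \
    --checkpoint.initial_load_model_weights_only
```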
16 changes: 16 additions & 0 deletions tests/integration_tests.py
@@ -118,6 +118,22 @@ def build_test_list():
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.folder hf_checkpoint",
"--checkpoint.last_save_in_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.initial_load_path artifacts-to-be-uploaded/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/",
],
],
"Checkpoint Integration Test - save load full checkpoint in HF safetensors format",
"full_checkpoint_hf_safetensors",
),
OverrideDefinitions(
[
[
4 changes: 2 additions & 2 deletions tests/unit_tests/test_checkpoint.py
@@ -144,7 +144,7 @@ def tearDown(self):
shutil.rmtree(self.base_temp_dir)
time.sleep(0.1)

def fake_save(self, state_dict: dict, checkpoint_id: str):
def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
os.makedirs(checkpoint_id, exist_ok=True)
sd_to_save = {}
for key, val in state_dict.items():
@@ -584,7 +584,7 @@ def __init__(self):
@mock.patch("torchtitan.components.checkpoint.dcp.load")
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
def fake_save(state_dict: dict, checkpoint_id: str):
def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
187 changes: 167 additions & 20 deletions torchtitan/components/checkpoint.py
@@ -12,12 +12,17 @@
import shutil
import threading
import time
from concurrent.futures import Future
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
@@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum):
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


class CheckpointType(str, enum.Enum):
DCP = "DCP"
SAFETENSORS = "safetensors"


# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
@@ -92,12 +102,6 @@ class SaveDone:
pass


@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection invoked by checkpointer.")


def purge_thread(purge_queue: queue.Queue):
"""Thread to purge the old checkpoints.
@@ -190,6 +194,9 @@ def __init__(
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.last_save_in_safetensors_format = (
ckpt_config.last_save_in_safetensors_format
)
self.ft_manager = (
ft_manager.manager if ft_manager and ft_manager.enabled else None
)
@@ -314,6 +321,98 @@ def close(self):
if self.stager is not None:
self.stager.close()

@torch.no_grad()
def dcp_save(
self,
state_dict: dict[str, Any],
checkpoint_id: str,
async_mode: AsyncMode,
enable_garbage_collection: bool = False,
save_in_safetensors_format: bool = False,
) -> Future | None:
"""Save the checkpoint with dcp.
Args:
state_dict (dict): The state dict to save.
checkpoint_id (str): The checkpoint id to save.
            async_mode (AsyncMode): Which async mode to use for this save.
enable_garbage_collection (bool): Whether to enable garbage collection after save.
save_in_safetensors_format (bool): Whether to save in safetensors format.
Returns:
Future: The future object if the checkpoint is async, otherwise None.
"""

ret: Future | None = None

storage_writer: HuggingFaceStorageWriter | None = None
checkpoint_save_id: str | None = None
if save_in_safetensors_format:
fqn_to_index_mapping = {}
num_fqns_per_file = 30
# the use of 30 is just a heuristic for now.
# Once these fqns map to HF ones, we can use the fqn mapping
# from the model.safetensors.index.json file
for i, key in enumerate(state_dict.keys()):
group_num = (i // num_fqns_per_file) + 1
fqn_to_index_mapping[key] = group_num

storage_writer = HuggingFaceStorageWriter(
path=checkpoint_id,
save_distributed=True,
fqn_to_index_mapping=fqn_to_index_mapping,
enable_consolidation=True,
thread_count_consolidation=5,
)
else:
checkpoint_save_id = checkpoint_id

if async_mode == AsyncMode.ASYNC:
ret = dcp.async_save(
state_dict,
storage_writer=storage_writer,
checkpoint_id=checkpoint_save_id,
process_group=self.pg,
)
elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
ret = dcp.async_save(
state_dict,
storage_writer=storage_writer,
checkpoint_id=checkpoint_save_id,
process_group=self.pg,
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
async_stager=self.stager,
)
else:
ret = dcp.save(
state_dict,
storage_writer=storage_writer,
checkpoint_id=checkpoint_save_id,
)

if enable_garbage_collection:
GarbageCollection.collect("GC collection invoked by checkpointer.")

return ret

def dcp_load(
self,
state_dict: dict[str, Any],
checkpoint_id: str,
checkpoint_type: CheckpointType,
) -> None:
"""Load the checkpoint with dcp.
Args:
state_dict (dict): The state dict to load.
checkpoint_id (str): The checkpoint id to load.
            checkpoint_type (CheckpointType): The type (DCP or safetensors) of the checkpoint to load.
"""

if checkpoint_type == CheckpointType.SAFETENSORS:
storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
dcp.load(state_dict, storage_reader=storage_reader)
Inline review comment (Contributor):
I have a question on how to combine this PR with torchtitan model <-> HF model definition conversion mappings, so that we can load / save with HF models and train with torchtitan.

Let's say we have two mappings, torchtitan_to_hf and hf_to_torchtitan, each taking a state dict and converting it to the other format, e.g. very similar to https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama4/_convert_weights.py

I can imagine how we do this for save: when HF format is used, we don't just call the current ModelWrapper.state_dict, we also apply the torchtitan_to_hf conversion on top of it.

How are we supposed to do load in the HF definition? ModelWrapper.load_state_dict is only on the fault-tolerant path and doesn't seem to be used by dcp.load, so how can we do things like fusing / unfusing tensors after load? I think we need to manually add a function in dcp_load that does the reverse of torchtitan_to_hf in place, similar to what ModelWrapper.load_state_dict does via set_model_state_dict. This also means hf_to_torchtitan itself is not useful.

cc @fegin @wwwjn @wesleytruong
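
For illustration only, a minimal sketch of the two mappings the comment describes, assuming a purely name-based remapping with hypothetical FQNs; real models would generate the table per layer from the model definition and may also need fuse/unfuse logic:

```python
import torch

# Hypothetical torchtitan FQN -> HF FQN table; a real mapping would be
# generated per layer (cf. torchtune's _convert_weights.py linked above).
_TT_TO_HF = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "output.weight": "lm_head.weight",
}


def torchtitan_to_hf(sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename torchtitan keys to HF keys before saving in HF format."""
    return {_TT_TO_HF[k]: v for k, v in sd.items()}


def hf_to_torchtitan(sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename HF keys back to torchtitan keys after loading an HF checkpoint."""
    hf_to_tt = {v: k for k, v in _TT_TO_HF.items()}
    return {hf_to_tt[k]: v for k, v in sd.items()}
```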

else:
dcp.load(state_dict, checkpoint_id=checkpoint_id)

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
"""Save the checkpoint for the current step.
@@ -354,23 +453,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
GarbageCollection.collect("GC collection invoked by checkpointer.")
if self.stager is None:
self.stager = DefaultStager(StagingOptions(True, True, True, True))
result = dcp.async_save(
result = self.dcp_save(
states,
checkpoint_id=checkpoint_id,
process_group=self.pg,
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
async_stager=self.stager,
async_mode=self.async_mode,
)
self.save_future = result.upload_completion
self.staging_future = result.staging_completion
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.save_future = dcp.async_save(
states, checkpoint_id=checkpoint_id, process_group=self.pg
self.save_future = self.dcp_save(
states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(states, checkpoint_id=checkpoint_id)
self.dcp_save(
states,
checkpoint_id=checkpoint_id,
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
)
self._purge_stale_checkpoints()

logger.info(
Expand Down Expand Up @@ -432,10 +534,19 @@ def load(self, step: int = -1) -> bool:
f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
)

checkpoint_type = self._find_checkpoint_type(checkpoint_id)
if checkpoint_type == CheckpointType.SAFETENSORS:
assert (
model_only
), "Only model weights can be loaded when loading from safetensors checkpoint."
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
dcp.load(states, checkpoint_id=checkpoint_id)
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
checkpoint_type=checkpoint_type,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -470,13 +581,33 @@ def _find_load_step(self, folder: str = "") -> int:

for filename in os.listdir(folder):
match = re.search(pattern, filename)
metadata_probe = os.path.join(folder, filename, ".metadata")
if match and os.path.isfile(metadata_probe):
dcp_metadata_probe = os.path.join(folder, filename, ".metadata")
safetensors_metadata_probe = os.path.join(
folder, filename, "model.safetensors.index.json"
)
if match and os.path.isfile(dcp_metadata_probe):
step_counts.append(int(match.group(1)))
elif match and os.path.isfile(safetensors_metadata_probe):
Inline review comment (Contributor):
Oh, for safetensors do we also require the step-10, step-20, etc. type of checkpoint folder structure? If users want to load an existing checkpoint from HF, is the workflow

  1. download the safetensors files into a local folder
  2. either put them in a step-0 subfolder, or use the initial_load_path config?
Reply (Contributor Author @ankitageorge, Jul 14, 2025):
From your other comment, it seems like you think an HF checkpoint at step 10, 20, etc. isn't a valid use case, so I will just follow your suggestion.
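
For illustration, one possible shape of the workflow asked about above; the repo id, paths, and the use of the huggingface_hub CLI are assumptions for the sketch, not part of this PR:

```bash
# 1. Download the safetensors checkpoint into a local folder
#    (any download method works; huggingface-cli is one option).
huggingface-cli download <hf_repo_id> --local-dir ./hf_checkpoint

# 2. Either place the files in a step-0 subfolder under --checkpoint.folder,
#    or point --checkpoint.initial_load_path directly at the downloaded folder:
NGPU=8 CONFIG=<path_to_model_config> ./run_train.sh \
    --checkpoint.enable_checkpoint \
    --checkpoint.initial_load_path ./hf_checkpoint \
    --checkpoint.initial_load_model_weights_only
```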

step_counts.append(int(match.group(1)))
if not step_counts:
return -1
return max(step_counts)

def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:
"""Find the checkpoint type for the given id.
Args:
checkpoint_id (str): The folder to find the checkpoint type for.
Returns:
CheckpointType: The checkpoint type for the given folder.
"""

for filename in os.listdir(checkpoint_id):
if filename == "model.safetensors.index.json":
return CheckpointType.SAFETENSORS
return CheckpointType.DCP

def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

@@ -488,8 +619,8 @@ def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.save_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
self.save_future = self.dcp_save(
self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

@@ -501,7 +632,12 @@ def _ft_load(self) -> None:
begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
self.dcp_load(
self.ft_states,
checkpoint_id=checkpoint_id,
            # FT checkpoints are always DCP because FT checkpointing currently only saves/loads the dataloader.
checkpoint_type=CheckpointType.DCP,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -570,7 +706,18 @@ def _save_last_step(self, curr_step: int) -> None:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
states = self._flattened_model_states_sd()

save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
if self.last_save_in_safetensors_format:
assert (
self.last_save_model_weights_only
), "Only model weights can be saved when saving in safetensors format."

self.dcp_save(
states,
checkpoint_id=self._create_checkpoint_id(curr_step),
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
save_in_safetensors_format=self.last_save_in_safetensors_format,
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable_checkpoint:
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -475,6 +475,17 @@ class Checkpoint:
for many steps or checkpointing too frequently. The default value is False.
"""

last_save_in_safetensors_format: bool = False
"""
    Enable the use of the safetensors format for checkpointing. This saves the final checkpoint
    in safetensors format instead of the default DCP format. There is a performance cost,
    since the sharded tensors must be consolidated into full tensors as a separate step.
    last_save_model_weights_only must be true because safetensors doesn't support saving
    non-tensor objects. On load, this argument isn't needed, as we detect whether the
    loaded checkpoint is in safetensors format or not.
    The default value is False.
"""


@dataclass
class ActivationCheckpoint: