Add support for saving HF format tensors with DCP #1351

Merged: 30 commits (merged Jul 14, 2025)
Changes from 21 commits
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -9,3 +9,4 @@ wandb
fsspec
tyro
tokenizers >= 0.15.0
safetensors
11 changes: 11 additions & 0 deletions tests/integration_tests.py
@@ -118,6 +118,17 @@ def build_test_list():
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_save_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint_hf_safetensors",
),
Contributor
We need to test both saving to an HF checkpoint and loading from an HF checkpoint in this test. Since we only support loading a weights-only initial checkpoint in HF format, I suggest the following for testing purposes only.

In the first run we save a step-10 checkpoint and treat it as the initial checkpoint to load; then in the second run we use initial_load_path to locate the one we saved. Something like the following, though you may have to tweak the paths a bit.

Suggested change
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_save_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint_hf_safetensors",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.folder hf_checkpoint",
"--checkpoint.enable_save_safetensors_format",
"--checkpoint.last_save_model_weights_only",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.initial_load_path outputs/hf_checkpoint/step-10",
],
],
"Checkpoint Integration Test - save load full checkpoint in HF safetensors format",
"full_checkpoint_hf_safetensors",
),

OverrideDefinitions(
[
[
4 changes: 2 additions & 2 deletions tests/unit_tests/test_checkpoint.py
@@ -144,7 +144,7 @@ def tearDown(self):
shutil.rmtree(self.base_temp_dir)
time.sleep(0.1)

def fake_save(self, state_dict: dict, checkpoint_id: str):
def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
os.makedirs(checkpoint_id, exist_ok=True)
sd_to_save = {}
for key, val in state_dict.items():
@@ -584,7 +584,7 @@ def __init__(self):
@mock.patch("torchtitan.components.checkpoint.dcp.load")
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
def fake_save(state_dict: dict, checkpoint_id: str):
def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
168 changes: 148 additions & 20 deletions torchtitan/components/checkpoint.py
@@ -12,12 +12,17 @@
import shutil
import threading
import time
from concurrent.futures import Future
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
@@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum):
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


class CheckpointType(str, enum.Enum):
DCP = "DCP"
SAFETENSORS = "safetensors"


# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
@@ -92,12 +102,6 @@ class SaveDone:
pass


@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection invoked by checkpointer.")


def purge_thread(purge_queue: queue.Queue):
"""Thread to purge the old checkpoints.
@@ -190,6 +194,7 @@ def __init__(
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.enable_save_safetensors_format = ckpt_config.enable_save_safetensors_format
self.ft_manager = ft_manager.manager if ft_manager.enabled else None

if self.ft_manager:
@@ -312,6 +317,89 @@ def close(self):
if self.stager is not None:
self.stager.close()

@torch.no_grad()
def dcp_save(
self,
state_dict: dict[str, Any],
checkpoint_id: str,
async_mode: AsyncMode,
enable_garbage_collection: bool = False,
is_last_step: bool = False,
) -> Future | None:
"""Save the checkpoint with dcp.
Args:
state_dict (dict): The state dict to save.
checkpoint_id (str): The checkpoint id to save.
async_mode (AsyncMode): Whether (and how) to save the checkpoint asynchronously.
enable_garbage_collection (bool): Whether to run garbage collection after saving.
is_last_step (bool): Whether this save is for the last step (controls safetensors output).
Returns:
Future: The future object if the checkpoint is async, otherwise None.
"""

ret: Future | None = None

storage_writer: HuggingFaceStorageWriter | None = None
if self.enable_save_safetensors_format and is_last_step:
fqn_to_index_mapping = {}
num_fqns_per_file = 30
# the use of 30 is just a heuristic for now.
# Once these fqns map to HF ones, we can use the fqn mapping
# from the model.safetensors.index.json file
for i, key in enumerate(state_dict.keys()):
group_num = (i // num_fqns_per_file) + 1
fqn_to_index_mapping[key] = group_num

storage_writer = HuggingFaceStorageWriter(
path=checkpoint_id,
save_distributed=True,
fqn_to_index_mapping=fqn_to_index_mapping,
enable_consolidation=is_last_step,
Contributor
seems more readable if

Suggested change
enable_consolidation=is_last_step,
enable_consolidation=True,

thread_count_consolidation=5,
)
id = checkpoint_id if storage_writer is None else None
if async_mode == AsyncMode.ASYNC:
ret = dcp.async_save(
state_dict,
storage_writer=storage_writer,
checkpoint_id=id,
process_group=self.pg,
)
elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
ret = dcp.async_save(
state_dict,
storage_writer=storage_writer,
checkpoint_id=id,
process_group=self.pg,
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
async_stager=self.stager,
)
else:
ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id)

if enable_garbage_collection:
GarbageCollection.collect("GC collection invoked by checkpointer.")

return ret

def dcp_load(
self,
state_dict: dict[str, Any],
checkpoint_id: str,
checkpoint_type: CheckpointType,
) -> None:
"""Load the checkpoint with dcp.
Args:
state_dict (dict): The state dict to load.
checkpoint_id (str): The checkpoint id to load.
checkpoint_type (CheckpointType): The format of the checkpoint to load (DCP or HF safetensors).
"""

if checkpoint_type == CheckpointType.SAFETENSORS:
storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
dcp.load(state_dict, storage_reader=storage_reader)
Contributor
I have a question on how to combine this PR with torchtitan model <-> HF model definition conversion mappings, so that we can load / save with HF models and train with torchtitan.

Let's say we have two mappings, torchtitan_to_hf and hf_to_torchtitan, each taking a state dict and converting it to the other, e.g. very similar to https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama4/_convert_weights.py

I can imagine how we do this for save -- when HF format is used, we don't return the current ModelWrapper.state_dict directly; we apply the conversion map torchtitan_to_hf on top of it.

How are we supposed to do load in the HF definition? ModelWrapper.load_state_dict is only used in the fault-tolerant path and doesn't seem to be used by dcp.load, so how can we do things like fusing / unfusing tensors after load? I think we need to manually add a function in dcp_load to do the reverse of torchtitan_to_hf in place, similar to what ModelWrapper.load_state_dict does in set_model_state_dict. This also means hf_to_torchtitan itself is not useful.

cc @fegin @wwwjn @wesleytruong
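
For illustration only, a minimal sketch of the kind of conversion mapping being discussed; the key names below are assumptions for a Llama-style model, not something defined in this PR:

import torch

def torchtitan_to_hf(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Hypothetical FQN mapping from torchtitan names to HF names.
    key_map = {
        "tok_embeddings.weight": "model.embed_tokens.weight",
        "output.weight": "lm_head.weight",
    }
    # Fused/unfused tensors (e.g. merged QKV) would need an explicit split/merge step here.
    return {key_map.get(k, k): v for k, v in state_dict.items()}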

else:
dcp.load(state_dict, checkpoint_id=checkpoint_id)

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
"""Save the checkpoint for the current step.
@@ -352,23 +440,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
GarbageCollection.collect("GC collection invoked by checkpointer.")
if self.stager is None:
self.stager = DefaultStager(StagingOptions(True, True, True, True))
result = dcp.async_save(
result = self.dcp_save(
states,
checkpoint_id=checkpoint_id,
process_group=self.pg,
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
async_stager=self.stager,
async_mode=self.async_mode,
)
self.save_future = result.upload_completion
self.staging_future = result.staging_completion
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.save_future = dcp.async_save(
states, checkpoint_id=checkpoint_id, process_group=self.pg
self.save_future = self.dcp_save(
states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(states, checkpoint_id=checkpoint_id)
self.dcp_save(
states,
checkpoint_id=checkpoint_id,
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
)
self._purge_stale_checkpoints()

logger.info(
@@ -430,10 +521,17 @@ def load(self, step: int = -1) -> bool:
f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
)

checkpoint_type = self._find_checkpoint_type(checkpoint_id)
if checkpoint_type == CheckpointType.SAFETENSORS:
model_only = True
Contributor
We should assert if model_only is not True, rather than silently changing the model_only value.

Contributor Author
How should I change the logic to allow for this? Right now, if os.path.exists(self.folder), there is no way for model_only to be True other than when step == 0. self.initial_load_model_weights_only isn't used in this code path either. It's also already being silently changed on line 516, which is why I did it this way, but I'm happy to change it in whatever way you think is best.

Contributor
This depends on your answer to my question about the loading folder structure.
If I'm right there -- that to load an HF safetensors checkpoint the user has to either use initial_load_path or put it in a step-x folder -- then I think it makes sense to assert model_only == True here,

unless you think it makes sense to put the HF checkpoint in some step-10 / 20 / 50 folder (why? I can't think of such a use case). If that's the case, we can modify the model_only logic on line 516 to reflect that.

Contributor Author
OK, if there is no such use case I will follow your advice; I wasn't sure.
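
For reference, a minimal sketch of the assertion being suggested here (illustrative only, not the merged code):

checkpoint_type = self._find_checkpoint_type(checkpoint_id)
if checkpoint_type == CheckpointType.SAFETENSORS:
    # A safetensors checkpoint holds weights only, so it can only be loaded model-only.
    assert model_only, "A safetensors checkpoint can only be loaded with model_only=True."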

logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
dcp.load(states, checkpoint_id=checkpoint_id)
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
checkpoint_type=checkpoint_type,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -468,13 +566,33 @@ def _find_load_step(self, folder: str = "") -> int:

for filename in os.listdir(folder):
match = re.search(pattern, filename)
metadata_probe = os.path.join(folder, filename, ".metadata")
if match and os.path.isfile(metadata_probe):
dcp_metadata_probe = os.path.join(folder, filename, ".metadata")
safetensors_metadata_probe = os.path.join(
folder, filename, "model.safetensors.index.json"
)
if match and os.path.isfile(dcp_metadata_probe):
step_counts.append(int(match.group(1)))
elif match and os.path.isfile(safetensors_metadata_probe):
Contributor
Oh, for safetensors do we also require the step-10, step-20, etc. type of checkpoint structure? If users want to load an existing checkpoint from HF, is the workflow:

  1. download the safetensors files into a local folder
  2. either put them in a step-0 subfolder, or use the initial_load_path config?

Contributor Author (@ankitageorge, Jul 14, 2025)
From your other comment, it seems you think putting the HF checkpoint in a step-10 / step-20 etc. folder isn't a valid use case, so I will just follow your suggestion.

step_counts.append(int(match.group(1)))
if not step_counts:
return -1
return max(step_counts)

def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:
"""Find the checkpoint type for the given id.
Args:
checkpoint_id (str): The folder to find the checkpoint type for.
Returns:
CheckpointType: The checkpoint type for the given folder.
"""

for filename in os.listdir(checkpoint_id):
if filename == "model.safetensors.index.json":
return CheckpointType.SAFETENSORS
return CheckpointType.DCP

def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

@@ -486,8 +604,8 @@ def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.save_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
self.save_future = self.dcp_save(
self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

@@ -499,7 +617,11 @@ def _ft_load(self) -> None:
begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
self.dcp_load(
self.ft_states,
checkpoint_id=checkpoint_id,
checkpoint_type=CheckpointType.DCP, # FT checkpoints are always DCP
Contributor
Please add one more line, "because FT checkpoint currently only save/load dataloader.".

)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -568,7 +690,13 @@ def _save_last_step(self, curr_step: int) -> None:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
states = self._flattened_model_states_sd()

save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
self.dcp_save(
states,
checkpoint_id=self._create_checkpoint_id(curr_step),
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
is_last_step=True,
Contributor
I suggest we don't pass is_last_step in, since it's only used when HF format matters.
I suggest we pass in save_in_safetensors_format = self.last_save_in_safetensors_format (after config name change).

At the beginning of this function, we need to assert last_save_model_weights_only == True if last_save_in_safetensors_format is set.
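
A minimal sketch of that suggestion (the config name assumes the proposed rename, so this is illustrative only):

if self.last_save_in_safetensors_format:
    # safetensors can only store tensors, so the final save must be weights-only.
    assert self.last_save_model_weights_only, (
        "last_save_in_safetensors_format requires last_save_model_weights_only=True"
    )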

)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable_checkpoint:
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -467,6 +467,17 @@ class Checkpoint:
for many steps or checkpointing too frequently. The default value is False.
"""

enable_save_safetensors_format: bool = False
Contributor
nit in naming

Suggested change
enable_save_safetensors_format: bool = False
last_save_in_safetensors_format: bool = False

"""
Enable the use of safetensors format for checkpointing. This will save the final checkpoints
in safetensors format instead of the default DCP format. There will be a performance
cost in using this as we need to consolidate the sharded tensors to full tensors as
a separate step. Last_save_model_weights must be true because safetensors doesn't
Contributor
Suggested change
a separate step. Last_save_model_weights must be true because safetensors doesn't
a separate step. last_save_model_weights_only must be true because safetensors doesn't

support saving non tensors. On load, this argument isn't needed as we will detect
whether the loaded checkpoint is in safetensors format or not.
The default value is False.
"""


@dataclass
class ActivationCheckpoint: