
Add support for saving HF format tensors with DCP #1351


Merged: 30 commits, Jul 14, 2025
15 changes: 15 additions & 0 deletions tests/integration_tests.py
@@ -118,6 +118,21 @@ def build_test_list():
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_hf_safetensors_format",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_hf_safetensors_format",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint_hf_safetensors",
),
OverrideDefinitions(
[
[
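For reference, a sketch of how this new test flavor might be run on its own through the integration-test harness. The output-directory argument and the --test/--ngpu flags are assumptions about tests/integration_tests.py's CLI, not shown in this diff:

python ./tests/integration_tests.py /tmp/titan_hf_ckpt --test full_checkpoint_hf_safetensors --ngpu 4  # flags assumed
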
112 changes: 100 additions & 12 deletions torchtitan/components/checkpoint.py
@@ -12,14 +12,19 @@
import shutil
import threading
import time
from typing import Any
from concurrent.futures import Future
from typing import Any, Optional

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
set_model_state_dict,
@@ -93,8 +98,64 @@ class SaveDone:


@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
def dcp_save(
state_dict: dict[str, Any],
checkpoint_id: str,
is_async: bool,
hf_safetensors_format: bool,
pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
"""Save the checkpoint with dcp.
Args:
state_dict (dict): The state dict to save.
checkpoint_id (str): The checkpoint id to save.
is_async (bool): Whether to save asynchronously; if True, a Future is returned.
hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
pg (Optional[dist.ProcessGroup]): The process group to use.
"""
if hf_safetensors_format:
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
if is_async:
return dcp.async_save(
state_dict, storage_writer=storage_writer, process_group=pg
)
else:
return dcp.save(state_dict, storage_writer=storage_writer)
else:
if is_async:
return dcp.async_save(
state_dict, checkpoint_id=checkpoint_id, process_group=pg
)
else:
return dcp.save(state_dict, checkpoint_id=checkpoint_id)
Contributor comment:

We should simplify the function as follows:

Suggested change:
    storage_writer = (
        HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if hf_safetensors_format
        else None
    )
    checkpoint_id = checkpoint_id if not hf_safetensors_format else None
    if is_async:
        return dcp.async_save(
            state_dict,
            storage_writer=storage_writer,
            checkpoint_id=checkpoint_id,
            process_group=pg,
        )
    else:
        return dcp.save(
            state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id
        )
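For context, this collapse works because both DCP entry points accept either a checkpoint_id or an explicit storage_writer, falling back to the default filesystem layout when the writer is None. An abridged signature sketch (keyword names from torch.distributed.checkpoint, trimmed; verify against the installed torch version):

def save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None): ...

def async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None): ...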



def dcp_load(
state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
) -> None:
"""Load the checkpoint with dcp.
Args:
state_dict (dict): The state dict to load.
checkpoint_id (str): The checkpoint id to load.
hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
"""
if hf_safetensors_format:
storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
dcp.load(state_dict, storage_reader=storage_reader)
else:
dcp.load(state_dict, checkpoint_id=checkpoint_id)


@torch.no_grad()
def save_with_gc(
state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
) -> None:
dcp_save(
state,
checkpoint_id=checkpoint_id,
is_async=False,
hf_safetensors_format=hf_safetensors_format,
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
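Taken together, dcp_save and dcp_load give a safetensors round trip. A minimal sketch of the intended flow (the toy module and paths are hypothetical; assumes a recent PyTorch where HuggingFaceStorageWriter/Reader are available and behave as this PR expects in single-process use):

import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import (
    HuggingFaceStorageReader,
    HuggingFaceStorageWriter,
)

# Hypothetical toy module; any state dict of tensors follows the same path.
model = nn.Linear(16, 16)
state_dict = model.state_dict()

# Save in HF safetensors format, mirroring dcp_save with is_async=False.
dcp.save(
    state_dict,
    storage_writer=HuggingFaceStorageWriter(path="/tmp/hf_ckpt", save_sharded=True),
)

# Load back in place, mirroring dcp_load; note the reader is passed as
# storage_reader, not storage_writer.
dcp.load(
    state_dict,
    storage_reader=HuggingFaceStorageReader(path="/tmp/hf_ckpt"),
)

The async variant (dcp.async_save, used by dcp_save when is_async=True) returns a concurrent.futures.Future; the CheckpointManager below stores it as self.async_future and waits on it before issuing the next save.
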


@@ -125,7 +186,9 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
assert isinstance(obj, tuple)
begin = time.monotonic()
state, checkpoint_id = obj
save_with_gc(state, checkpoint_id=checkpoint_id)
save_with_gc(
state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
)
logger.info(
"Finish saving the checkpoint in the background process in %.2f seconds.",
time.monotonic() - begin,
@@ -227,6 +290,7 @@ def __init__(
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format
self.ft_manager = ft_manager.manager if ft_manager.enabled else None

if self.ft_manager:
@@ -391,12 +455,20 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
self.async_future = dcp_save(
self.states,
checkpoint_id=checkpoint_id,
is_async=True,
hf_safetensors_format=self.enable_hf_safetensors_format,
pg=self.pg,
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
save_with_gc(
self.states,
checkpoint_id=checkpoint_id,
hf_safetensors_format=self.enable_hf_safetensors_format,
)
self._purge_stale_checkpoints()

logger.info(
@@ -461,7 +533,11 @@ def load(self, step: int = -1) -> bool:
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
dcp.load(states, checkpoint_id=checkpoint_id)
dcp_load(
states,
checkpoint_id=checkpoint_id,
hf_safetensors_format=self.enable_hf_safetensors_format,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -540,8 +616,12 @@ def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.async_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
self.async_future = dcp_save(
self.ft_states,
checkpoint_id=checkpoint_id,
is_async=True,
hf_safetensors_format=self.enable_hf_safetensors_format,
pg=self.pg,
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

@@ -553,7 +633,11 @@ def _ft_load(self) -> None:
begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
dcp_load(
self.ft_states,
checkpoint_id=checkpoint_id,
hf_safetensors_format=self.enable_hf_safetensors_format,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -614,7 +698,11 @@ def _save_last_step(self, curr_step: int) -> None:
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
save_with_gc(
self.states,
checkpoint_id=self._create_checkpoint_id(curr_step),
hf_safetensors_format=self.enable_hf_safetensors_format,
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable_checkpoint:
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -467,6 +467,12 @@ class Checkpoint:
for many steps or checkpointing too frequently. The default value is False.
"""

enable_hf_safetensors_format: bool = False
"""
Enable the use of the HuggingFace safetensors format for checkpointing. Checkpoints
will be saved in safetensors format instead of the default DCP format. The default
value is False.
"""

Contributor comment:

Can we also mention the possible performance penalty? It's not cost free, right?


@dataclass
class ActivationCheckpoint: