
Commit c0c4448: add hf support

Parent: aefe15a

File tree

tests/integration_tests.py
torchtitan/components/checkpoint.py
torchtitan/config_manager.py

3 files changed: +121 -12 lines

tests/integration_tests.py

Lines changed: 15 additions & 0 deletions
@@ -118,6 +118,21 @@ def build_test_list():
             "Checkpoint Integration Test - Save Load Full Checkpoint",
             "full_checkpoint",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.enable_hf_safetensors_format",
+                ],
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.enable_hf_safetensors_format",
+                    "--training.steps 20",
+                ],
+            ],
+            "Checkpoint Integration Test - Save Load Full Checkpoint",
+            "full_checkpoint_hf_safetensors",
+        ),
         OverrideDefinitions(
             [
                 [

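(As with the existing full_checkpoint entry above, the two override lists presumably drive a save/load pair: one run that writes a checkpoint with the safetensors flag enabled, then a 20-step run that resumes from it, exercising both the HuggingFace writer and reader paths end to end.)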
torchtitan/components/checkpoint.py

Lines changed: 100 additions & 12 deletions
@@ -12,14 +12,19 @@
 import shutil
 import threading
 import time
-from typing import Any
+from concurrent.futures import Future
+from typing import Any, Optional

 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
+)
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -93,8 +98,64 @@ class SaveDone:


 @torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-    dcp.save(state, checkpoint_id=checkpoint_id)
+def dcp_save(
+    state_dict: dict[str, Any],
+    checkpoint_id: str,
+    is_async: bool,
+    hf_safetensors_format: bool,
+    pg: Optional[dist.ProcessGroup] = None,
+) -> Optional[Future]:
+    """Save the checkpoint with dcp.
+    Args:
+        state_dict (dict): The state dict to save.
+        checkpoint_id (str): The checkpoint id to save.
+        is_async (bool): Whether the checkpoint is async.
+        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
+        pg (Optional[dist.ProcessGroup]): The process group to use.
+    """
+    if hf_safetensors_format:
+        storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
+        if is_async:
+            return dcp.async_save(
+                state_dict, storage_writer=storage_writer, process_group=pg
+            )
+        else:
+            return dcp.save(state_dict, storage_writer=storage_writer)
+    else:
+        if is_async:
+            return dcp.async_save(
+                state_dict, checkpoint_id=checkpoint_id, process_group=pg
+            )
+        else:
+            return dcp.save(state_dict, checkpoint_id=checkpoint_id)
+
+
+def dcp_load(
+    state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
+) -> None:
+    """Load the checkpoint with dcp.
+    Args:
+        state_dict (dict): The state dict to load.
+        checkpoint_id (str): The checkpoint id to load.
+        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
+    """
+    if hf_safetensors_format:
+        storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+        dcp.load(state_dict, storage_reader=storage_reader)
+    else:
+        dcp.load(state_dict, checkpoint_id=checkpoint_id)
+
+
+@torch.no_grad()
+def save_with_gc(
+    state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
+) -> None:
+    dcp_save(
+        state,
+        checkpoint_id=checkpoint_id,
+        is_async=False,
+        hf_safetensors_format=hf_safetensors_format,
+    )
     GarbageCollection.collect("GC collection invoked by checkpointer.")


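For reference, a minimal round-trip sketch of the two new helpers (not part of this commit). It assumes a single-process run, a PyTorch build that ships the HuggingFace storage backends imported above, and a hypothetical /tmp/ckpt_hf path:

import torch.nn as nn

from torchtitan.components.checkpoint import dcp_load, dcp_save

# Toy state dict; in torchtitan this would be the trainer's model/optimizer state.
model = nn.Linear(4, 4)
state = {"model": model.state_dict()}

# Synchronous save in sharded safetensors format.
dcp_save(
    state,
    checkpoint_id="/tmp/ckpt_hf",
    is_async=False,
    hf_safetensors_format=True,
)

# Load back in place: DCP restores values into the tensors already in `state`.
dcp_load(state, checkpoint_id="/tmp/ckpt_hf", hf_safetensors_format=True)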
@@ -125,7 +186,9 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
             assert isinstance(obj, tuple)
             begin = time.monotonic()
             state, checkpoint_id = obj
-            save_with_gc(state, checkpoint_id=checkpoint_id)
+            save_with_gc(
+                state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
+            )
             logger.info(
                 "Finish saving the checkpoint in the background process in %.2f seconds.",
                 time.monotonic() - begin,
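Note that this background-process path hard-codes hf_safetensors_format=False, so checkpoints staged through checkpoint_mp (the async-with-pinned-memory mode) are always written in the default DCP format in this commit, regardless of checkpoint.enable_hf_safetensors_format.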
@@ -227,6 +290,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format
         self.ft_manager = ft_manager.manager if ft_manager.enabled else None

         if self.ft_manager:
@@ -391,12 +455,20 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+            self.async_future = dcp_save(
+                self.states,
+                checkpoint_id=checkpoint_id,
+                is_async=True,
+                hf_safetensors_format=self.enable_hf_safetensors_format,
+                pg=self.pg,
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(
+                self.states,
+                checkpoint_id=checkpoint_id,
+                hf_safetensors_format=self.enable_hf_safetensors_format,
+            )
         self._purge_stale_checkpoints()

         logger.info(
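Because dcp_save returns Optional[Future] for async saves, callers keep the handle on self.async_future and must wait on it before starting the next save (compare the self._async_wait() call in _ft_save below). A hedged sketch of that pattern, assuming an initialized default process group; the state dict and checkpoint id are stand-ins:

from torchtitan.components.checkpoint import dcp_save

states = {"step": 100}  # stand-in for the real training state

future = dcp_save(
    states,
    checkpoint_id="/tmp/ckpt_async",  # hypothetical id
    is_async=True,
    hf_safetensors_format=False,
    pg=None,  # falls back to the default process group
)
if future is not None:
    future.result()  # block until the save finishes; re-raises any save error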
@@ -461,7 +533,11 @@ def load(self, step: int = -1) -> bool:
         logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
         states = self._states_to_load(model_only)
-        dcp.load(states, checkpoint_id=checkpoint_id)
+        dcp_load(
+            states,
+            checkpoint_id=checkpoint_id,
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -540,8 +616,12 @@ def _ft_save(self, step: int) -> None:
         begin = time.monotonic()
         self._async_wait()
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.async_future = dcp.async_save(
-            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        self.async_future = dcp_save(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+            is_async=True,
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+            pg=self.pg,
         )
         logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

@@ -553,7 +633,11 @@ def _ft_load(self) -> None:
         begin = time.monotonic()
         logger.info(f"Loading the FT checkpoint at step {step}.")
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        dcp_load(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -614,7 +698,11 @@ def _save_last_step(self, curr_step: int) -> None:
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(
+            self.states,
+            checkpoint_id=self._create_checkpoint_id(curr_step),
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+        )

     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
@@ -467,6 +467,12 @@ class Checkpoint:
     for many steps or checkpointing too frequently. The default value is False.
     """

+    enable_hf_safetensors_format: bool = False
+    """
+    Enable the use of safetensors format for checkpointing. This will save checkpoints
+    in safetensors format instead of the default DCP format. The default value is False.
+    """
+

 @dataclass
 class ActivationCheckpoint:

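A hedged usage sketch for the new option, assuming (as the surrounding file suggests) that Checkpoint is a module-level dataclass whose fields all have defaults. On the command line the same flag is spelled --checkpoint.enable_hf_safetensors_format, which is how the new integration test drives it:

from torchtitan.config_manager import Checkpoint

# Enable checkpointing and select the safetensors on-disk format.
ckpt_cfg = Checkpoint(
    enable_checkpoint=True,
    enable_hf_safetensors_format=True,
)
assert ckpt_cfg.enable_hf_safetensors_format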