Commit f6e0022: add hf format dcp
1 parent aefe15a

File tree: 2 files changed, +106 −16 lines

torchtitan/components/checkpoint.py

Lines changed: 100 additions & 16 deletions
@@ -12,14 +12,20 @@
 import shutil
 import threading
 import time
-from typing import Any
+from concurrent.futures import Future
+from typing import Any, Optional
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
+)
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -92,12 +98,6 @@ class SaveDone:
     pass
 
 
-@torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-    dcp.save(state, checkpoint_id=checkpoint_id)
-    GarbageCollection.collect("GC collection invoked by checkpointer.")
-
-
 def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
     """Process to save the checkpoint in the background.
@@ -125,7 +125,9 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
         assert isinstance(obj, tuple)
         begin = time.monotonic()
         state, checkpoint_id = obj
-        save_with_gc(state, checkpoint_id=checkpoint_id)
+        save_with_gc(
+            state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
+        )
         logger.info(
             "Finish saving the checkpoint in the background process in %.2f seconds.",
             time.monotonic() - begin,
@@ -135,6 +137,69 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
     dist.destroy_process_group()
 
 
+@torch.no_grad()
+def dcp_save(
+    state_dict: dict[str, Any],
+    checkpoint_id: str,
+    is_async: bool,
+    hf_safetensors_format: bool,
+    pg: Optional[dist.ProcessGroup] = None,
+) -> Optional[Future]:
+    """Save the checkpoint for the current step.
+
+    Args:
+        state_dict (dict): The state dict to save.
+        checkpoint_id (str): The checkpoint id to save.
+        is_async (bool): Whether to save the checkpoint asynchronously.
+        hf_safetensors_format (bool): Whether to save in HuggingFace safetensors format.
+        pg (Optional[dist.ProcessGroup]): The process group to use for async saves.
+    """
+    if hf_safetensors_format:
+        storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
+        if is_async:
+            return dcp.async_save(
+                state_dict, storage_writer=storage_writer, process_group=pg
+            )
+        else:
+            return dcp.save(state_dict, storage_writer=storage_writer)
+    else:
+        if is_async:
+            return dcp.async_save(
+                state_dict, checkpoint_id=checkpoint_id, process_group=pg
+            )
+        else:
+            return dcp.save(state_dict, checkpoint_id=checkpoint_id)
+
+
+def dcp_load(
+    state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
+) -> None:
+    """Load the checkpoint for the current step.
+
+    Args:
+        state_dict (dict): The state dict to load into.
+        checkpoint_id (str): The checkpoint id to load.
+        hf_safetensors_format (bool): Whether the checkpoint is in HuggingFace safetensors format.
+    """
+    if hf_safetensors_format:
+        storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+        dcp.load(state_dict, storage_reader=storage_reader)
+    else:
+        dcp.load(state_dict, checkpoint_id=checkpoint_id)
+
+
+@torch.no_grad()
+def save_with_gc(
+    state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
+) -> None:
+    dcp_save(
+        state,
+        checkpoint_id=checkpoint_id,
+        is_async=False,
+        hf_safetensors_format=hf_safetensors_format,
+    )
+    GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.
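
For orientation, here is a minimal round-trip sketch of the two new helpers (not part of the commit). It assumes dcp_save and dcp_load are importable from torchtitan.components.checkpoint, a single-rank gloo process group, and a scratch directory under /tmp; set hf_safetensors_format=True on both calls to exercise the HuggingFace safetensors path instead.

# Illustrative round trip for dcp_save/dcp_load; assumptions as noted above.
import os

import torch.distributed as dist
import torch.nn as nn

from torchtitan.components.checkpoint import dcp_load, dcp_save

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(4, 4)
# Synchronous save: writes native DCP files under checkpoint_id.
dcp_save(
    model.state_dict(),
    checkpoint_id="/tmp/dcp_demo",
    is_async=False,
    hf_safetensors_format=False,
)

# DCP loads in place into a pre-allocated state dict.
restored = nn.Linear(4, 4)
state = restored.state_dict()
dcp_load(state, checkpoint_id="/tmp/dcp_demo", hf_safetensors_format=False)
restored.load_state_dict(state)

dist.destroy_process_group()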
@@ -227,6 +292,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format
         self.ft_manager = ft_manager.manager if ft_manager.enabled else None
 
         if self.ft_manager:
@@ -391,12 +457,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+            self.async_future = dcp_save(
+                self.states,
+                checkpoint_id=checkpoint_id,
+                is_async=True,
+                hf_safetensors_format=self.enable_hf_safetensors_format,
+                pg=self.pg,
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(
+                self.states,
+                checkpoint_id=checkpoint_id,
+                hf_safetensors_format=self.enable_hf_safetensors_format,
+            )
         self._purge_stale_checkpoints()
 
         logger.info(
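
Worth noting for callers: with is_async=True, dcp_save returns the concurrent.futures.Future from dcp.async_save, and it is the caller's job to wait on it before issuing the next save; the Checkpointer does this via _async_wait(). Continuing the single-rank sketch above:

# Async variant (illustrative): the staged save runs in the background and
# the returned Future resolves once the checkpoint is durable.
future = dcp_save(
    model.state_dict(),
    checkpoint_id="/tmp/dcp_demo_async",
    is_async=True,
    hf_safetensors_format=False,
    pg=None,  # None falls back to the default process group
)
future.result()  # block before issuing another save or tearing down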
@@ -461,7 +531,11 @@ def load(self, step: int = -1) -> bool:
         logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
         states = self._states_to_load(model_only)
-        dcp.load(states, checkpoint_id=checkpoint_id)
+        dcp_load(
+            states,
+            checkpoint_id=checkpoint_id,
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -540,8 +614,12 @@ def _ft_save(self, step: int) -> None:
         begin = time.monotonic()
         self._async_wait()
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.async_future = dcp.async_save(
-            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        self.async_future = dcp_save(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+            is_async=True,
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+            pg=self.pg,
         )
         logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

@@ -553,7 +631,11 @@ def _ft_load(self) -> None:
         begin = time.monotonic()
         logger.info(f"Loading the FT checkpoint at step {step}.")
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        dcp_load(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -614,7 +696,9 @@ def _save_last_step(self, curr_step: int) -> None:
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
 
-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(
+            self.states,
+            checkpoint_id=self._create_checkpoint_id(curr_step),
+            hf_safetensors_format=self.enable_hf_safetensors_format,
+        )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
@@ -467,6 +467,12 @@ class Checkpoint:
     for many steps or checkpointing too frequently. The default value is False.
     """
 
+    enable_hf_safetensors_format: bool = False
+    """
+    Enable the use of safetensors format for checkpointing. This will save checkpoints
+    in safetensors format instead of the default DCP format. The default value is False.
+    """
+
 
 @dataclass
 class ActivationCheckpoint:
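
Fields on the Checkpoint dataclass surface as keys under the [checkpoint] table of torchtitan's TOML job configs, so enabling the new format from a config file would plausibly look like the following sketch (the enable_checkpoint key is shown for context; only the last key is new in this commit):

# job config sketch; only enable_hf_safetensors_format is added by this commit
[checkpoint]
enable_checkpoint = true
enable_hf_safetensors_format = true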
