12
12
import shutil
13
13
import threading
14
14
import time
15
- from typing import Any
15
+ from concurrent .futures import Future
16
+ from typing import Any , Optional
16
17
17
18
import torch
18
19
import torch .distributed as dist
19
20
import torch .distributed .checkpoint as dcp
20
21
import torch .multiprocessing as mp
21
22
import torch .nn as nn
22
23
from torch .distributed ._state_dict_utils import _copy_state_dict , _create_cpu_state_dict
24
+ from torch .distributed .checkpoint import (
25
+ HuggingFaceStorageReader ,
26
+ HuggingFaceStorageWriter ,
27
+ )
23
28
from torch .distributed .checkpoint .state_dict import (
24
29
get_model_state_dict ,
25
30
set_model_state_dict ,
@@ -93,8 +98,64 @@ class SaveDone:
93
98
94
99
95
100
@torch.no_grad()
def dcp_save(
    state_dict: dict[str, Any],
    checkpoint_id: str,
    is_async: bool,
    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
    """Save ``state_dict`` with torch.distributed.checkpoint (dcp).

    Args:
        state_dict (dict): The state dict to save.
        checkpoint_id (str): The checkpoint id (path) to save to.
        is_async (bool): Whether to save asynchronously (returns a Future).
        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
        pg (Optional[dist.ProcessGroup]): The process group to use for async saves.

    Returns:
        A ``Future`` for async saves, otherwise the result of ``dcp.save`` (None).
    """
    if hf_safetensors_format:
        # Sharded safetensors output goes through an explicit storage writer.
        writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if is_async:
            return dcp.async_save(state_dict, storage_writer=writer, process_group=pg)
        return dcp.save(state_dict, storage_writer=writer)

    # Default DCP on-disk format: the checkpoint id alone selects the storage.
    if is_async:
        return dcp.async_save(state_dict, checkpoint_id=checkpoint_id, process_group=pg)
    return dcp.save(state_dict, checkpoint_id=checkpoint_id)
131
+
132
+
133
def dcp_load(
    state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
) -> None:
    """Load the checkpoint with dcp (in place into ``state_dict``).

    Args:
        state_dict (dict): The state dict to load into.
        checkpoint_id (str): The checkpoint id (path) to load from.
        hf_safetensors_format (bool): Whether to read the HuggingFace safetensors format.
    """
    if hf_safetensors_format:
        storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
        # Bug fix: the reader must be passed as ``storage_reader``.
        # The original passed ``storage_writer=storage_reader``, which is not a
        # valid keyword for dcp.load and raised a TypeError whenever
        # hf_safetensors_format was enabled.
        dcp.load(state_dict, storage_reader=storage_reader)
    else:
        dcp.load(state_dict, checkpoint_id=checkpoint_id)
147
+
148
+
149
@torch.no_grad()
def save_with_gc(
    state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
) -> None:
    """Synchronously save ``state`` via ``dcp_save``, then trigger a GC pass."""
    dcp_save(
        state_dict=state,
        checkpoint_id=checkpoint_id,
        is_async=False,
        hf_safetensors_format=hf_safetensors_format,
    )
    # Saving allocates a lot of transient objects; collect eagerly afterwards.
    GarbageCollection.collect("GC collection invoked by checkpointer.")
99
160
100
161
@@ -125,7 +186,9 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
125
186
assert isinstance (obj , tuple )
126
187
begin = time .monotonic ()
127
188
state , checkpoint_id = obj
128
- save_with_gc (state , checkpoint_id = checkpoint_id )
189
+ save_with_gc (
190
+ state , checkpoint_id = checkpoint_id , hf_safetensors_format = False
191
+ )
129
192
logger .info (
130
193
"Finish saving the checkpoint in the background process in %.2f seconds." ,
131
194
time .monotonic () - begin ,
@@ -227,6 +290,7 @@ def __init__(
227
290
) -> None :
228
291
ckpt_config = job_config .checkpoint
229
292
self .enable_checkpoint = ckpt_config .enable_checkpoint
293
+ self .enable_hf_safetensors_format = ckpt_config .enable_hf_safetensors_format
230
294
self .ft_manager = ft_manager .manager if ft_manager .enabled else None
231
295
232
296
if self .ft_manager :
@@ -391,12 +455,20 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
391
455
self ._async_with_pinned_memory (checkpoint_id )
392
456
elif self .async_mode == AsyncMode .ASYNC :
393
457
GarbageCollection .collect ("GC collection invoked by checkpointer." )
394
- self .async_future = dcp .async_save (
395
- self .states , checkpoint_id = checkpoint_id , process_group = self .pg
458
+ self .async_future = dcp_save (
459
+ self .states ,
460
+ checkpoint_id = checkpoint_id ,
461
+ is_async = True ,
462
+ hf_safetensors_format = self .enable_hf_safetensors_format ,
463
+ pg = self .pg ,
396
464
)
397
465
GarbageCollection .collect ("GC collection invoked by checkpointer." )
398
466
else :
399
- save_with_gc (self .states , checkpoint_id = checkpoint_id )
467
+ save_with_gc (
468
+ self .states ,
469
+ checkpoint_id = checkpoint_id ,
470
+ hf_safetensors_format = self .enable_hf_safetensors_format ,
471
+ )
400
472
self ._purge_stale_checkpoints ()
401
473
402
474
logger .info (
@@ -461,7 +533,11 @@ def load(self, step: int = -1) -> bool:
461
533
logger .info (f"Loading the checkpoint from { checkpoint_id } ." )
462
534
begin = time .monotonic ()
463
535
states = self ._states_to_load (model_only )
464
- dcp .load (states , checkpoint_id = checkpoint_id )
536
+ dcp_load (
537
+ states ,
538
+ checkpoint_id = checkpoint_id ,
539
+ hf_safetensors_format = self .enable_hf_safetensors_format ,
540
+ )
465
541
GarbageCollection .collect ("GC collection for checkpoint loading." )
466
542
logger .info (
467
543
f"Finished loading the checkpoint in { time .monotonic () - begin :.2f} seconds."
@@ -540,8 +616,12 @@ def _ft_save(self, step: int) -> None:
540
616
begin = time .monotonic ()
541
617
self ._async_wait ()
542
618
checkpoint_id = self ._create_checkpoint_id (step , folder = self ._ft_folder ())
543
- self .async_future = dcp .async_save (
544
- self .ft_states , checkpoint_id = checkpoint_id , process_group = self .pg
619
+ self .async_future = dcp_save (
620
+ self .ft_states ,
621
+ checkpoint_id = checkpoint_id ,
622
+ is_async = True ,
623
+ hf_safetensors_format = self .enable_hf_safetensors_format ,
624
+ pg = self .pg ,
545
625
)
546
626
logger .info (f"Staging ft checkpoint took { time .monotonic () - begin } secs." )
547
627
@@ -553,7 +633,11 @@ def _ft_load(self) -> None:
553
633
begin = time .monotonic ()
554
634
logger .info (f"Loading the FT checkpoint at step { step } ." )
555
635
checkpoint_id = self ._create_checkpoint_id (step , folder = self ._ft_folder ())
556
- dcp .load (self .ft_states , checkpoint_id = checkpoint_id )
636
+ dcp_load (
637
+ self .ft_states ,
638
+ checkpoint_id = checkpoint_id ,
639
+ hf_safetensors_format = self .enable_hf_safetensors_format ,
640
+ )
557
641
GarbageCollection .collect ("GC collection for checkpoint loading." )
558
642
logger .info (
559
643
f"Finished loading the ft checkpoint in { time .monotonic () - begin :.2f} seconds."
@@ -614,7 +698,11 @@ def _save_last_step(self, curr_step: int) -> None:
614
698
else :
615
699
logger .info (f"Saving a full checkpoint at last step, step { curr_step } ." )
616
700
617
- save_with_gc (self .states , checkpoint_id = self ._create_checkpoint_id (curr_step ))
701
+ save_with_gc (
702
+ self .states ,
703
+ checkpoint_id = self ._create_checkpoint_id (curr_step ),
704
+ hf_safetensors_format = self .enable_hf_safetensors_format ,
705
+ )
618
706
619
707
def _should_save (self , curr_step : int , last_step : bool = False ) -> bool :
620
708
if not self .enable_checkpoint :
0 commit comments