 import shutil
 import threading
 import time
- from typing import Any
+ from concurrent.futures import Future
+ from typing import Any, Optional

 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+ from torch.distributed.checkpoint import (
+     HuggingFaceStorageReader,
+     HuggingFaceStorageWriter,
+ )
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -92,12 +98,6 @@ class SaveDone:
     pass


- @torch.no_grad()
- def save_with_gc(state, checkpoint_id):
-     dcp.save(state, checkpoint_id=checkpoint_id)
-     GarbageCollection.collect("GC collection invoked by checkpointer.")
-
-
 def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
     """Process to save the checkpoint in the background.
@@ -125,7 +125,9 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
             assert isinstance(obj, tuple)
             begin = time.monotonic()
             state, checkpoint_id = obj
-             save_with_gc(state, checkpoint_id=checkpoint_id)
+             # The background process always writes the native DCP format.
+             save_with_gc(
+                 state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
+             )
             logger.info(
                 "Finish saving the checkpoint in the background process in %.2f seconds.",
                 time.monotonic() - begin,
@@ -135,6 +137,69 @@ def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
         dist.destroy_process_group()

+ @torch.no_grad()
+ def dcp_save(
+     state_dict: dict[str, Any],
+     checkpoint_id: str,
+     is_async: bool,
+     hf_safetensors_format: bool,
+     pg: Optional[dist.ProcessGroup] = None,
+ ) -> Optional[Future]:
+     """Save the checkpoint for the current step.
+
+     Args:
+         state_dict (dict): The state dict to save.
+         checkpoint_id (str): The checkpoint id to save.
+         is_async (bool): Whether to save the checkpoint asynchronously.
+         hf_safetensors_format (bool): Whether to save in HuggingFace safetensors format.
+         pg (ProcessGroup, optional): The process group used for async saves.
+
+     Returns:
+         A Future if the save is asynchronous, otherwise None.
+     """
+     if hf_safetensors_format:
+         storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
+         if is_async:
+             return dcp.async_save(
+                 state_dict, storage_writer=storage_writer, process_group=pg
+             )
+         else:
+             return dcp.save(state_dict, storage_writer=storage_writer)
+     else:
+         if is_async:
+             return dcp.async_save(
+                 state_dict, checkpoint_id=checkpoint_id, process_group=pg
+             )
+         else:
+             return dcp.save(state_dict, checkpoint_id=checkpoint_id)
+
+
+ def dcp_load(
+     state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
+ ) -> None:
+     """Load the checkpoint for the current step.
+
+     Args:
+         state_dict (dict): The state dict to load into.
+         checkpoint_id (str): The checkpoint id to load.
+         hf_safetensors_format (bool): Whether the checkpoint uses HuggingFace safetensors format.
+     """
+     if hf_safetensors_format:
+         storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+         dcp.load(state_dict, storage_reader=storage_reader)
+     else:
+         dcp.load(state_dict, checkpoint_id=checkpoint_id)
+
+
+ @torch.no_grad()
+ def save_with_gc(
+     state: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
+ ) -> None:
+     dcp_save(
+         state,
+         checkpoint_id=checkpoint_id,
+         is_async=False,
+         hf_safetensors_format=hf_safetensors_format,
+     )
+     GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.

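For orientation, here is a minimal single-process sketch of how the two new helpers compose. The toy model and the /tmp path are illustrative only; it assumes the safetensors package is installed and a PyTorch build that exports HuggingFaceStorageReader/Writer from torch.distributed.checkpoint, as the imports above do:

    import torch.nn as nn

    # Illustrative toy state; DCP can also run single-process without a process group.
    model = nn.Linear(8, 8)
    state_dict = {"model": model.state_dict()}

    # Blocking save; hf_safetensors_format=True routes through HuggingFaceStorageWriter.
    dcp_save(
        state_dict,
        checkpoint_id="/tmp/ckpt-demo",
        is_async=False,
        hf_safetensors_format=True,
    )

    # dcp.load restores tensors in place into the provided state_dict.
    dcp_load(state_dict, checkpoint_id="/tmp/ckpt-demo", hf_safetensors_format=True)
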
@@ -227,6 +292,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+         self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format
         self.ft_manager = ft_manager.manager if ft_manager.enabled else None

         if self.ft_manager:
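Note that __init__ reads the new attribute unconditionally, so every job config must now define it. A hedged defensive variant (an editorial sketch, not what this PR does) would tolerate configs that predate the field:

    # Sketch only: fall back to False when an older job config lacks the field.
    self.enable_hf_safetensors_format = getattr(
        ckpt_config, "enable_hf_safetensors_format", False
    )
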
@@ -391,12 +457,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
-             self.async_future = dcp.async_save(
-                 self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+             self.async_future = dcp_save(
+                 self.states,
+                 checkpoint_id=checkpoint_id,
+                 is_async=True,
+                 hf_safetensors_format=self.enable_hf_safetensors_format,
+                 pg=self.pg,
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-             save_with_gc(self.states, checkpoint_id=checkpoint_id)
+             save_with_gc(
+                 self.states,
+                 checkpoint_id=checkpoint_id,
+                 hf_safetensors_format=self.enable_hf_safetensors_format,
+             )
         self._purge_stale_checkpoints()

         logger.info(
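Because dcp_save(..., is_async=True) wraps dcp.async_save, the assignment above stores a concurrent.futures.Future. A sketch of how such a future is typically drained before the next save (the _async_wait name does appear in this file; this body is an assumption, and the real method also covers the background-process mode):

    def _async_wait(self) -> None:
        # Block until the previous async save has finished writing.
        if self.async_future is not None:
            self.async_future.result()
            self.async_future = None
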
@@ -461,7 +531,11 @@ def load(self, step: int = -1) -> bool:
         logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
         states = self._states_to_load(model_only)
-         dcp.load(states, checkpoint_id=checkpoint_id)
+         dcp_load(
+             states,
+             checkpoint_id=checkpoint_id,
+             hf_safetensors_format=self.enable_hf_safetensors_format,
+         )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -540,8 +614,12 @@ def _ft_save(self, step: int) -> None:
         begin = time.monotonic()
         self._async_wait()
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-         self.async_future = dcp.async_save(
-             self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+         self.async_future = dcp_save(
+             self.ft_states,
+             checkpoint_id=checkpoint_id,
+             is_async=True,
+             hf_safetensors_format=self.enable_hf_safetensors_format,
+             pg=self.pg,
         )
         logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
@@ -553,7 +631,11 @@ def _ft_load(self) -> None:
         begin = time.monotonic()
         logger.info(f"Loading the FT checkpoint at step {step}.")
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-         dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+         dcp_load(
+             self.ft_states,
+             checkpoint_id=checkpoint_id,
+             hf_safetensors_format=self.enable_hf_safetensors_format,
+         )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -614,7 +696,9 @@ def _save_last_step(self, curr_step: int) -> None:
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

-         save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+         save_with_gc(
+             self.states,
+             checkpoint_id=self._create_checkpoint_id(curr_step),
+             hf_safetensors_format=self.enable_hf_safetensors_format,
+         )

     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint: