@@ -98,45 +98,6 @@ class SaveDone:
     pass
 
 
-def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
-    """Process to save the checkpoint in the background.
-
-    This is only used when async_checkpoint_with_pinned_memory is enabled.
-
-    Args:
-        recv (mp.Queue): The queue to receive the state_dict and Terminate signal.
-        send (mp.Queue): The queue to send the SaveDone signal.
-    """
-    init_logger()
-    os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
-    os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    dist.init_process_group()
-    try:
-        while True:
-            logger.debug("Checkpoint background process is done.")
-            send.put(SaveDone())
-            logger.debug("Wait for the new state_dict.")
-            obj = recv.get()
-            logger.debug("Received the new state_dict.")
-            if isinstance(obj, Terminate):
-                logger.info("Terminating the checkpoint background process.")
-                return
-            assert isinstance(obj, tuple)
-            begin = time.monotonic()
-            state, checkpoint_id = obj
-            save_with_gc(
-                state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
-            )
-            logger.info(
-                "Finish saving the checkpoint in the background process in %.2f seconds.",
-                time.monotonic() - begin,
-            )
-    finally:
-        logger.info("Destroying the process group.")
-        dist.destroy_process_group()
-
-
 @torch.no_grad()
 def dcp_save(
     state_dict: dict[str, Any],
@@ -145,13 +106,13 @@ def dcp_save(
     hf_safetensors_format: bool,
     pg: Optional[dist.ProcessGroup] = None,
 ) -> Optional[Future]:
-    """Save the checkpoint for the current step.
-
-
+    """Save the checkpoint with dcp.
     Args:
         state_dict (dict): The state dict to save.
         checkpoint_id (str): The checkpoint id to save.
         is_async (bool): Whether the checkpoint is async.
+        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
+        pg (Optional[dist.ProcessGroup]): The process group to use.
     """
     if hf_safetensors_format:
         storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
@@ -173,12 +134,11 @@ def dcp_save(
 def dcp_load(
     state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
 ) -> None:
-    """Save the checkpoint for the current step.
-
-
+    """Load the checkpoint with dcp.
     Args:
         state_dict (dict): The state dict to load.
         checkpoint_id (str): The checkpoint id to load.
+        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
     """
     if hf_safetensors_format:
         storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
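
The two hunks above extend dcp_save/dcp_load with an hf_safetensors_format switch (HuggingFaceStorageWriter/Reader versus the default DCP storage). A minimal single-process sketch of driving that round trip follows; the checkpoint path, tensor contents, and process-group setup are made up for illustration, the helpers are assumed importable from the module this diff patches, and the safetensors path additionally needs the optional HuggingFace dependencies installed.

import os

import torch
import torch.distributed as dist

# dcp_save / dcp_load: the helpers defined in this diff (import path not shown here).

# Single-process "world" so DCP has a process group to work with.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

state = {"weight": torch.zeros(4)}
dcp_save(
    state,
    checkpoint_id="/tmp/step-100",   # hypothetical output directory
    is_async=False,
    hf_safetensors_format=True,      # route through HuggingFaceStorageWriter
)

restored = {"weight": torch.empty(4)}
dcp_load(restored, checkpoint_id="/tmp/step-100", hf_safetensors_format=True)

dist.destroy_process_group()
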
@@ -200,6 +160,45 @@ def save_with_gc(
     GarbageCollection.collect("GC collection invoked by checkpointer.")
 
 
+def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
+    """Process to save the checkpoint in the background.
+
+    This is only used when async_checkpoint_with_pinned_memory is enabled.
+
+    Args:
+        recv (mp.Queue): The queue to receive the state_dict and Terminate signal.
+        send (mp.Queue): The queue to send the SaveDone signal.
+    """
+    init_logger()
+    os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
+    os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    dist.init_process_group()
+    try:
+        while True:
+            logger.debug("Checkpoint background process is done.")
+            send.put(SaveDone())
+            logger.debug("Wait for the new state_dict.")
+            obj = recv.get()
+            logger.debug("Received the new state_dict.")
+            if isinstance(obj, Terminate):
+                logger.info("Terminating the checkpoint background process.")
+                return
+            assert isinstance(obj, tuple)
+            begin = time.monotonic()
+            state, checkpoint_id = obj
+            save_with_gc(
+                state, checkpoint_id=checkpoint_id, hf_safetensors_format=False
+            )
+            logger.info(
+                "Finish saving the checkpoint in the background process in %.2f seconds.",
+                time.monotonic() - begin,
+            )
+    finally:
+        logger.info("Destroying the process group.")
+        dist.destroy_process_group()
+
+
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.
 
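
The relocated checkpoint_mp loop implies a simple queue protocol for the trainer side: wait for SaveDone, hand over a (state_dict, checkpoint_id) tuple, and eventually send Terminate. Below is a hedged sketch of that parent side only; the queue wiring, the cpu_state_dict name, and the checkpoint path are hypothetical (the real checkpointer manages this process itself), and checkpoint_mp additionally expects the usual torchrun environment variables (MASTER_PORT, LOCAL_RANK, ...) to be set.

import torch.multiprocessing as mp

ctx = mp.get_context("spawn")
to_child = ctx.Queue()     # consumed as `recv` inside checkpoint_mp
from_child = ctx.Queue()   # filled as `send` inside checkpoint_mp

proc = ctx.Process(target=checkpoint_mp, args=(to_child, from_child), daemon=True)
proc.start()

# The child announces it is idle before each checkpoint.
assert isinstance(from_child.get(), SaveDone)

# Hand over a staged copy of the state; the child saves it via save_with_gc.
to_child.put((cpu_state_dict, "/tmp/step-200"))   # cpu_state_dict: hypothetical staged copy

# ... later, shut the loop down and reap the process.
to_child.put(Terminate())
proc.join()
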
@@ -466,7 +465,11 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            self.save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(
+                self.states,
+                checkpoint_id=checkpoint_id,
+                hf_safetensors_format=self.enable_hf_safetensors_format,
+            )
         self._purge_stale_checkpoints()
 
         logger.info(
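
Finally, the synchronous save path in the last hunk now goes through the module-level save_with_gc and forwards the safetensors flag. The body of save_with_gc is not shown in this diff; judging from the hunk header at old line 200, a plausible shape, written out here purely as an assumption, is a blocking dcp_save followed by the GC pass that appears as context above:

def save_with_gc(state, checkpoint_id, hf_safetensors_format):
    # Assumed body (not part of this diff): a synchronous dcp_save, then the
    # GarbageCollection.collect(...) line visible as context in the hunk above.
    dcp_save(
        state,
        checkpoint_id=checkpoint_id,
        is_async=False,
        hf_safetensors_format=hf_safetensors_format,
    )
    GarbageCollection.collect("GC collection invoked by checkpointer.")
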