17
17
import torch
18
18
import torch .distributed as dist
19
19
import torch .distributed .checkpoint as dcp
20
- import torch .multiprocessing as mp
21
20
import torch .nn as nn
22
- from torch .distributed ._state_dict_utils import _copy_state_dict , _create_cpu_state_dict
21
+ from torch .distributed .checkpoint . staging import DefaultStager , StagingOptions
23
22
from torch .distributed .checkpoint .state_dict import (
24
23
get_model_state_dict ,
25
24
set_model_state_dict ,
26
25
StateDictOptions ,
27
26
)
27
+ from torch .distributed .checkpoint .state_dict_saver import AsyncCheckpointerType
28
28
from torch .distributed .checkpoint .stateful import Stateful
29
29
from torch .utils .data import DataLoader
30
30
31
31
from torchtitan .components .ft import FTManager
32
32
from torchtitan .components .lr_scheduler import LRSchedulersContainer
33
33
from torchtitan .components .optimizer import OptimizersContainer
34
34
from torchtitan .config_manager import JobConfig , TORCH_DTYPE_MAP
35
- from torchtitan .tools .logging import init_logger , logger
35
+ from torchtitan .tools .logging import logger
36
36
from torchtitan .tools .utils import GarbageCollection
37
37
38
38
@@ -98,43 +98,6 @@ def save_with_gc(state, checkpoint_id):
98
98
GarbageCollection .collect ("GC collection invoked by checkpointer." )
99
99
100
100
101
- def checkpoint_mp (recv : mp .Queue , send : mp .Queue ):
102
- """Process to save the checkpoint in the background.
103
-
104
- This is only used when async_checkpoint_with_pinned_memory is enabled.
105
-
106
- Args:
107
- recv (mp.Queue): The queue to receive the state_dict and Terminate signal.
108
- send (mp.Queue): The queue to send the SaveDone signal.
109
- """
110
- init_logger ()
111
- os .environ ["MASTER_PORT" ] = str (int (os .environ ["MASTER_PORT" ]) + 2 )
112
- os .environ ["TORCHELASTIC_USE_AGENT_STORE" ] = "False"
113
- torch .cuda .set_device (int (os .environ ["LOCAL_RANK" ]))
114
- dist .init_process_group ()
115
- try :
116
- while True :
117
- logger .debug ("Checkpoint background process is done." )
118
- send .put (SaveDone ())
119
- logger .debug ("Wait for the new state_dict." )
120
- obj = recv .get ()
121
- logger .debug ("Received the new state_dict." )
122
- if isinstance (obj , Terminate ):
123
- logger .info ("Terminating the checkpoint background process." )
124
- return
125
- assert isinstance (obj , tuple )
126
- begin = time .monotonic ()
127
- state , checkpoint_id = obj
128
- save_with_gc (state , checkpoint_id = checkpoint_id )
129
- logger .info (
130
- "Finish saving the checkpoint in the background process in %.2f seconds." ,
131
- time .monotonic () - begin ,
132
- )
133
- finally :
134
- logger .info ("Destroying the process group." )
135
- dist .destroy_process_group ()
136
-
137
-
138
101
def purge_thread (purge_queue : queue .Queue ):
139
102
"""Thread to purge the old checkpoints.
140
103
@@ -275,7 +238,7 @@ def load_state_dict(state_dict):
275
238
self .sending_to_checkpoint_mp = False
276
239
self .staging_id = None
277
240
self .cpu_offload_state_dict = None
278
- self .staging_stream = torch . cuda . Stream () if self . enable_staging else None
241
+ self .stager = None
279
242
280
243
self .folder = os .path .join (job_config .job .dump_folder , ckpt_config .folder )
281
244
@@ -292,7 +255,11 @@ def load_state_dict(state_dict):
292
255
293
256
# Async checkpoint related fields.
294
257
async_mode = ckpt_config .async_mode .lower ()
295
- if async_mode == AsyncMode .ASYNC or self .ft_manager :
258
+ if (
259
+ async_mode == AsyncMode .ASYNC
260
+ or async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM
261
+ or self .ft_manager
262
+ ):
296
263
self .pg = dist .new_group (backend = "gloo" )
297
264
298
265
self .keep_latest_k = ckpt_config .keep_latest_k
@@ -311,25 +278,14 @@ def load_state_dict(state_dict):
311
278
self .purge_thread = None
312
279
313
280
self .mp = None
314
- self .async_future = None
281
+ self .staging_future = None
282
+ self .save_future = None
315
283
if async_mode == AsyncMode .DISABLED :
316
284
self .async_mode = AsyncMode .DISABLED
317
285
elif async_mode == AsyncMode .ASYNC :
318
286
self .async_mode = AsyncMode .ASYNC
319
287
elif async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM :
320
288
self .async_mode = AsyncMode .ASYNC_WITH_PINNED_MEM
321
- ctx = mp .get_context ("spawn" )
322
- self .mp_queue_send = ctx .Queue ()
323
- self .mp_queue_recv = ctx .Queue ()
324
- self .mp = ctx .Process (
325
- target = checkpoint_mp ,
326
- args = (
327
- self .mp_queue_send ,
328
- self .mp_queue_recv ,
329
- ),
330
- daemon = True ,
331
- )
332
- self .mp .start ()
333
289
else :
334
290
raise ValueError (f"Unkown checkpoint async_mode { ckpt_config .async_mode } " )
335
291
@@ -353,6 +309,9 @@ def close(self):
353
309
self .purge_queue .put (Terminate ())
354
310
self .purge_thread .join ()
355
311
312
+ if self .stager is not None :
313
+ self .stager .close ()
314
+
356
315
@torch .no_grad ()
357
316
def save (self , curr_step : int , last_step : bool = False ) -> None :
358
317
"""Save the checkpoint for the current step.
@@ -388,10 +347,20 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
388
347
self ._save_last_step (curr_step )
389
348
elif self .async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM :
390
349
GarbageCollection .collect ("GC collection invoked by checkpointer." )
391
- self ._async_with_pinned_memory (checkpoint_id )
350
+ if self .stager is None :
351
+ self .stager = DefaultStager (StagingOptions (True , True , True , True ))
352
+ result = dcp .async_save (
353
+ self .states ,
354
+ checkpoint_id = checkpoint_id ,
355
+ process_group = self .pg ,
356
+ async_checkpointer_type = AsyncCheckpointerType .PROCESS ,
357
+ async_stager = self .stager ,
358
+ )
359
+ self .save_future = result .upload_completion
360
+ self .staging_future = result .staging_completion
392
361
elif self .async_mode == AsyncMode .ASYNC :
393
362
GarbageCollection .collect ("GC collection invoked by checkpointer." )
394
- self .async_future = dcp .async_save (
363
+ self .save_future = dcp .async_save (
395
364
self .states , checkpoint_id = checkpoint_id , process_group = self .pg
396
365
)
397
366
GarbageCollection .collect ("GC collection invoked by checkpointer." )
@@ -475,33 +444,7 @@ def maybe_wait_for_staging(self) -> None:
475
444
with ``async_checkpoint_with_pinned_memory``.
476
445
"""
477
446
if self .enable_staging and self .staging :
478
- if not self .staging_stream .query ():
479
- begin = time .monotonic ()
480
- self .staging_stream .synchronize ()
481
- logger .info (
482
- "Checkpointer waited staging %.2f seconds." ,
483
- time .monotonic () - begin ,
484
- )
485
- self .staging = False
486
-
487
- if self .sending_to_checkpoint_mp :
488
- # Copy the sync staging result to another process.
489
- def sync_func ():
490
- self .mp_queue_send .put_nowait (
491
- (self .cpu_offload_state_dict , self .staging_id )
492
- )
493
-
494
- # This may be a faster way to do zero-overhead checkpointing staging
495
- # checkpointing but we need more thorough investigation before
496
- # swithing to this method.
497
- # self.my_thread = threading.Thread(target=func).start()
498
- begin = time .monotonic ()
499
- sync_func ()
500
- logger .info (
501
- "Checkpointer sent staged state_dict to another process %.2f seconds" ,
502
- time .monotonic () - begin ,
503
- )
504
- self .sending_to_checkpoint_mp = False
447
+ self .staging_future .result ()
505
448
506
449
def _find_load_step (self , folder : str = "" ) -> int :
507
450
"""Find the step to load the checkpoint for.
@@ -540,7 +483,7 @@ def _ft_save(self, step: int) -> None:
540
483
begin = time .monotonic ()
541
484
self ._async_wait ()
542
485
checkpoint_id = self ._create_checkpoint_id (step , folder = self ._ft_folder ())
543
- self .async_future = dcp .async_save (
486
+ self .save_future = dcp .async_save (
544
487
self .ft_states , checkpoint_id = checkpoint_id , process_group = self .pg
545
488
)
546
489
logger .info (f"Staging ft checkpoint took { time .monotonic () - begin } secs." )
@@ -633,45 +576,18 @@ def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
633
576
634
577
def _async_wait (self ) -> None :
635
578
if self .async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM :
636
- logger .debug (
637
- f"Waiting for the background process to finish, { time .monotonic ()= } .:.2f"
638
- )
639
- if not self .mp .is_alive ():
640
- raise RuntimeError ("The checkpoint background process is dead." )
641
- _ = self .mp_queue_recv .get ()
579
+ if self .save_future is not None :
580
+ self .save_future .result ()
642
581
elif self .async_mode == AsyncMode .ASYNC or self .ft_manager is not None :
643
- if self .async_future is not None :
644
- self .async_future .result ()
645
- self .async_future = None
646
- elif self .async_future is not None :
582
+ if self .save_future is not None :
583
+ self .save_future .result ()
584
+ self .save_future = None
585
+ elif self .save_future is not None :
647
586
raise RuntimeError (
648
- "self.async_future is not None, but self.async_mode is not enabled "
587
+ "self.save_future is not None, but self.async_mode is not enabled "
649
588
"and fault tolerance is not active."
650
589
)
651
590
652
- def _async_with_pinned_memory (self , checkpoint_id : str ) -> None :
653
- self ._cpu_staging (checkpoint_id )
654
- self .sending_to_checkpoint_mp = True
655
-
656
- def _cpu_staging (self , checkpoint_id : str | None ) -> None :
657
- """Offload state_dict to CPU memory"""
658
- state_dict = dcp .state_dict_saver ._stateful_to_state_dict (self .states )
659
- if self .cpu_offload_state_dict is None :
660
- logger .debug (f"Preparing the CPU memory, { time .monotonic ()= } .:.2f" )
661
- self .cpu_offload_state_dict = _create_cpu_state_dict (
662
- state_dict , pin_memory = True , share_memory = True
663
- )
664
-
665
- logger .debug (f"Staging the state_dict, { time .monotonic ()= } .:.2f" )
666
- with torch .cuda .stream (self .staging_stream ):
667
- self .cpu_offload_state_dict = _copy_state_dict (
668
- state_dict ,
669
- self .cpu_offload_state_dict ,
670
- non_blocking = True ,
671
- )
672
- self .staging = True
673
- self .staging_id = checkpoint_id
674
-
675
591
def _purge_stale_checkpoints (self ):
676
592
if (
677
593
self .keep_latest_k > 0
0 commit comments