@@ -345,12 +345,15 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         # freed until _async_wait()
         if last_step:
             self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            return
+
+        states = self._flattened_model_states_sd()
+        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             if self.stager is None:
                 self.stager = DefaultStager(StagingOptions(True, True, True, True))
             result = dcp.async_save(
-                self.states,
+                states,
                 checkpoint_id=checkpoint_id,
                 process_group=self.pg,
                 async_checkpointer_type=AsyncCheckpointerType.PROCESS,
@@ -361,11 +364,11 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.save_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                states, checkpoint_id=checkpoint_id, process_group=self.pg
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(states, checkpoint_id=checkpoint_id)
         self._purge_stale_checkpoints()

         logger.info(
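For orientation, a minimal standalone sketch of the dcp.async_save call used above; the single-rank gloo group, the loopback address, and the temporary directory are assumptions for the example, not part of this change. As in the comment at the top of the hunk, the staged state dict must stay alive until the returned future completes.

import os
import tempfile

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

states = {"weight": torch.randn(4, 4)}  # already flattened, as in this PR
future = dcp.async_save(states, checkpoint_id=tempfile.mkdtemp())
future.result()  # block until the background save finishes
dist.destroy_process_group()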
@@ -502,6 +505,19 @@ def _ft_load(self) -> None:
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
         )

+    def _flattened_model_states_sd(
+        self, state_dict: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        """Flatten the model states into a single dictionary.
+
+        Note that other states, such as optimizer states, are not flattened.
+        """
+        states = state_dict if state_dict is not None else self.states
+        sd = {k: v for k, v in states.items() if k != MODEL}
+        if MODEL in states:
+            sd.update(states[MODEL].state_dict())
+        return sd
+
     def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """Determines which states to load for the given step.

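To make the new helper's behavior concrete, a small self-contained sketch; the nn.Linear model, the toy step entry, and the standalone function are illustrative stand-ins, not code from this PR:

import torch
import torch.nn as nn

MODEL = "model"  # stand-in for the MODEL constant used by the checkpointer

def flattened_model_states_sd(states: dict) -> dict:
    # Same logic as the helper above: keep every non-model entry, then
    # promote the model's own state_dict keys to the top level.
    sd = {k: v for k, v in states.items() if k != MODEL}
    if MODEL in states:
        sd.update(states[MODEL].state_dict())
    return sd

model = nn.Linear(4, 2)
states = {MODEL: model, "step": torch.tensor(10)}
print(sorted(flattened_model_states_sd(states)))
# -> ['bias', 'step', 'weight']: model parameters now sit beside the other states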
@@ -516,8 +532,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            return sd
+            return self.states[MODEL].state_dict()

         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
@@ -527,6 +542,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }

+        states_to_load = self._flattened_model_states_sd(states_to_load)
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)

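Since _states_to_load now flattens the same way, the keys requested at load time line up with what the save path wrote. A rough single-process round trip under assumed paths (recent torch should allow dcp.save/dcp.load without a process group in a single process):

import tempfile

import torch
import torch.distributed.checkpoint as dcp

ckpt_dir = tempfile.mkdtemp()

# Save a flattened dict, then load it back into pre-allocated tensors.
states = {"weight": torch.randn(4, 4)}
dcp.save(states, checkpoint_id=ckpt_dir)

loaded = {"weight": torch.empty(4, 4)}
dcp.load(loaded, checkpoint_id=ckpt_dir)  # fills the tensors in place
assert torch.equal(loaded["weight"], states["weight"])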
@@ -539,25 +556,19 @@ def _save_last_step(self, curr_step: int) -> None:
         # current dtype is not the same as the export dtype at the end of the training.

         if self.last_save_model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #     'tok_embeddings.weight': ...,
-            #     'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
+            states = self.states[MODEL].state_dict()

             if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
+                states = {k: v.to(self.export_dtype) for k, v in states.items()}
             logger.info(
                 f"Saving a model weights only checkpoint in {self.export_dtype} "
                 f"at last step, step {curr_step}."
             )
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
+            states = self._flattened_model_states_sd()

-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))

     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
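Finally, note that _save_last_step now writes to a local states variable instead of mutating self.states, so the checkpointer's registry is left intact after the last save. The dtype cast itself is a plain comprehension, sketched here with made-up tensors and an assumed bfloat16 export dtype:

import torch

export_dtype = torch.bfloat16  # assumed config value for the example
states = {"tok_embeddings.weight": torch.randn(8, 4)}  # made-up weights

# Mirrors the cast in _save_last_step: float32 exports are left untouched,
# anything else is converted tensor by tensor before the final save.
if export_dtype != torch.float32:
    states = {k: v.to(export_dtype) for k, v in states.items()}

assert all(v.dtype == export_dtype for v in states.values())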