@@ -578,6 +578,46 @@ def __init__(self):
 
         manager.close()
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_verify_prefix(self, mock_save, mock_load, mock_rank):
+        def fake_save(state_dict: dict, checkpoint_id: str):
+            self.assertIn("bias", state_dict)
+            self.assertIn("weight", state_dict)
+            # No model prefix
+            self.assertNotIn("model", state_dict)
+            if "step-1" in checkpoint_id:
+                self.assertIn("optimizer", state_dict)
+                self.fake_save(state_dict, checkpoint_id)
+            else:
+                self.assertNotIn("optimizer", state_dict)
+            return
+
+        def fake_load(state_dict: dict, checkpoint_id=None):
+            self.assertIn("bias", state_dict)
+            self.assertIn("weight", state_dict)
+            # No model prefix
+            self.assertNotIn("model", state_dict)
+            self.assertNotIn("optimizer", state_dict)
+
+        self.job_config.checkpoint.last_save_model_weights_only = True
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        mock_save.side_effect = fake_save
+        mock_load.side_effect = fake_load
+        manager.save(curr_step=1)
+        manager.save(curr_step=2, last_step=True)
+        manager.load(step=1)
+
 
 if __name__ == "__main__":
     unittest.main()
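Note: a minimal, self-contained sketch of the mock.patch + side_effect pattern this test relies on — the real dcp.save/dcp.load calls are intercepted and routed into local checker functions, so assertions run against the exact state_dict the CheckpointManager hands to the storage layer. The real_save/save_checkpoint names below are hypothetical stand-ins for illustration, not torchtitan APIs.

import unittest
from unittest import mock


def real_save(state_dict, checkpoint_id):
    # Stand-in for a storage backend such as dcp.save (hypothetical).
    raise RuntimeError("the real backend should never run in the test")


def save_checkpoint(state_dict, checkpoint_id):
    # Code under test: delegates persistence to the backend.
    real_save(state_dict, checkpoint_id)


class SideEffectPatternTest(unittest.TestCase):
    @mock.patch(f"{__name__}.real_save")
    def test_backend_sees_flat_keys(self, mock_save):
        seen = []

        def fake_save(state_dict, checkpoint_id):
            # Runs at call time, with the exact arguments the code
            # under test passed to the patched backend.
            self.assertIn("weight", state_dict)
            seen.append(checkpoint_id)

        mock_save.side_effect = fake_save
        save_checkpoint({"weight": 1.0}, checkpoint_id="step-1")
        self.assertEqual(seen, ["step-1"])


if __name__ == "__main__":
    unittest.main()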