@@ -580,6 +580,46 @@ def __init__(self):
580
580
581
581
manager .close ()
582
582
583
    @mock.patch("torch.distributed.get_rank", return_value=0)
    @mock.patch("torchtitan.components.checkpoint.dcp.load")
    @mock.patch("torchtitan.components.checkpoint.dcp.save")
    def test_verify_prefix(self, mock_save, mock_load, mock_rank):
        """Verify the key layout of state dicts passed to dcp.save / dcp.load.

        Checks that flattened model parameter names ("bias", "weight") appear
        without a "model" prefix, and that optimizer state is present only in
        the regular (non weights-only) checkpoint.

        dcp.save and dcp.load are patched, so no real checkpoint I/O occurs;
        torch.distributed.get_rank is patched to return 0 so the manager runs
        as rank 0.
        """

        def fake_save(state_dict: dict, checkpoint_id: str):
            # Model tensors must be exposed under their flat names.
            self.assertIn("bias", state_dict)
            self.assertIn("weight", state_dict)
            # No model prefix
            self.assertNotIn("model", state_dict)
            if "step-1" in checkpoint_id:
                # Step 1 is a regular checkpoint: optimizer state is included.
                self.assertIn("optimizer", state_dict)
                # Persist through the class-level helper so manager.load(step=1)
                # below can read this checkpoint back.
                # NOTE(review): self.fake_save is assumed to be defined on the
                # test class elsewhere in the file — confirm.
                self.fake_save(state_dict, checkpoint_id)
            else:
                # Final save with last_save_model_weights_only=True: model
                # weights only, no optimizer state.
                self.assertNotIn("optimizer", state_dict)
                return

        def fake_load(state_dict: dict, checkpoint_id=None):
            # Loading expects unprefixed model tensors and, per this test's
            # assertions, no optimizer state in the requested state dict.
            self.assertIn("bias", state_dict)
            self.assertIn("weight", state_dict)
            # No model prefix
            self.assertNotIn("model", state_dict)
            self.assertNotIn("optimizer", state_dict)

        # Make the final (last_step) save export model weights only.
        self.job_config.checkpoint.last_save_model_weights_only = True
        manager = CheckpointManager(
            dataloader=self.data_loader,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states=self.states,
            job_config=self.job_config,
            ft_manager=self.ft_manager,
        )

        # Route the patched dcp entry points through the verifying fakes.
        mock_save.side_effect = fake_save
        mock_load.side_effect = fake_load
        manager.save(curr_step=1)  # regular checkpoint (optimizer included)
        manager.save(curr_step=2, last_step=True)  # weights-only final save
        manager.load(step=1)
583
623
584
624
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
0 commit comments