Skip to content

Commit 4e9cb40

Browse files
committed
unittest
1 parent c3dc50a commit 4e9cb40

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,46 @@ def __init__(self):
580580

581581
manager.close()
582582

583+
@mock.patch("torch.distributed.get_rank", return_value=0)
584+
@mock.patch("torchtitan.components.checkpoint.dcp.load")
585+
@mock.patch("torchtitan.components.checkpoint.dcp.save")
586+
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
587+
def fake_save(state_dict: dict, checkpoint_id: str):
588+
self.assertIn("bias", state_dict)
589+
self.assertIn("weight", state_dict)
590+
# No model prefix
591+
self.assertNotIn("model", state_dict)
592+
if "step-1" in checkpoint_id:
593+
self.assertIn("optimizer", state_dict)
594+
self.fake_save(state_dict, checkpoint_id)
595+
else:
596+
self.assertNotIn("optimizer", state_dict)
597+
return
598+
599+
def fake_load(state_dict: dict, checkpoint_id=None):
600+
self.assertIn("bias", state_dict)
601+
self.assertIn("weight", state_dict)
602+
# No model prefix
603+
self.assertNotIn("model", state_dict)
604+
self.assertNotIn("optimizer", state_dict)
605+
606+
self.job_config.checkpoint.last_save_model_weights_only = True
607+
manager = CheckpointManager(
608+
dataloader=self.data_loader,
609+
model_parts=self.model_parts,
610+
optimizers=self.optimizers,
611+
lr_schedulers=self.lr_schedulers,
612+
states=self.states,
613+
job_config=self.job_config,
614+
ft_manager=self.ft_manager,
615+
)
616+
617+
mock_save.side_effect = fake_save
618+
mock_load.side_effect = fake_load
619+
manager.save(curr_step=1)
620+
manager.save(curr_step=2, last_step=True)
621+
manager.load(step=1)
622+
583623

584624
if __name__ == "__main__":
585625
unittest.main()

0 commit comments

Comments
 (0)