Skip to content

Commit 83ff853

Browse files
committed
unittest
1 parent 4f05370 commit 83ff853

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,46 @@ def __init__(self):
578578

579579
manager.close()
580580

581+
@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.load")
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
    """Verify state dicts passed to dcp.save/dcp.load use flattened keys.

    Patches ``dcp.save`` and ``dcp.load`` with assertion hooks so every
    checkpoint round-trip checks that:
      * raw parameter keys ("bias", "weight") are present at the top level,
      * the "model" wrapper key is absent (no "model." prefix),
      * "optimizer" state appears only in the regular (non-last-step) save.
    """

    def fake_save(state_dict: dict, checkpoint_id: str):
        # Hook installed as dcp.save's side_effect: inspect exactly what the
        # CheckpointManager hands to the save API.
        self.assertIn("bias", state_dict)
        self.assertIn("weight", state_dict)
        # No model prefix
        self.assertNotIn("model", state_dict)
        if "step-1" in checkpoint_id:
            # Regular save (step 1): full state including optimizer.
            self.assertIn("optimizer", state_dict)
            # NOTE(review): delegates to a class-level `fake_save` helper
            # defined elsewhere in this test class — presumably so the
            # step-1 checkpoint is materialized for the load(step=1) below;
            # confirm against the rest of the file.
            self.fake_save(state_dict, checkpoint_id)
        else:
            # Last-step save with last_save_model_weights_only=True:
            # optimizer state must be stripped.
            self.assertNotIn("optimizer", state_dict)
        return

    def fake_load(state_dict: dict, checkpoint_id=None):
        # Hook installed as dcp.load's side_effect: the state dict the
        # manager prepares for loading must also be flattened, and this
        # test asserts it carries no optimizer entry.
        self.assertIn("bias", state_dict)
        self.assertIn("weight", state_dict)
        # No model prefix
        self.assertNotIn("model", state_dict)
        self.assertNotIn("optimizer", state_dict)

    # Enable weights-only behavior for the final checkpoint; the step-2
    # (last_step=True) save path is what exercises it.
    self.job_config.checkpoint.last_save_model_weights_only = True
    manager = CheckpointManager(
        dataloader=self.data_loader,
        model_parts=self.model_parts,
        optimizers=self.optimizers,
        lr_schedulers=self.lr_schedulers,
        states=self.states,
        job_config=self.job_config,
        ft_manager=self.ft_manager,
    )

    # Wire the assertion hooks into the mocked DCP entry points, then drive
    # one regular save, one last-step save, and one load.
    mock_save.side_effect = fake_save
    mock_load.side_effect = fake_load
    manager.save(curr_step=1)
    manager.save(curr_step=2, last_step=True)
    manager.load(step=1)
620+
581621

582622
if __name__ == "__main__":
583623
unittest.main()

torchtitan/components/checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,9 @@ def _save_last_step(self, curr_step: int) -> None:
632632
else:
633633
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
634634

635-
save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
635+
save_with_gc(
636+
self._flattend_model_states_sd(), checkpoint_id=self._create_checkpoint_id(curr_step)
637+
)
636638

637639
def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
638640
if not self.enable_checkpoint:

0 commit comments

Comments (0)