From c3dc50a1c0615f0d5d7165f93064189e8c75f505 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 26 Jun 2025 10:51:33 -0700
Subject: [PATCH 1/5] [WIP][RFC] Always flatten model state_dict

---
 scripts/convert_llama_to_dcp.py     |  2 +-
 scripts/generate/test_generate.py   |  4 +--
 tests/unit_tests/test_checkpoint.py | 22 +++++++--------
 torchtitan/components/checkpoint.py | 43 ++++++++++++++++++-----------
 4 files changed, 40 insertions(+), 31 deletions(-)

diff --git a/scripts/convert_llama_to_dcp.py b/scripts/convert_llama_to_dcp.py
index cac1a908e..02f371c0c 100644
--- a/scripts/convert_llama_to_dcp.py
+++ b/scripts/convert_llama_to_dcp.py
@@ -125,7 +125,7 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
     logger.info(f"Writing to DCP at '{output_dir}'")
     output_dir.mkdir(parents=True, exist_ok=True)
     storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
-    DCP.save({"model": state_dict}, storage_writer=storage_writer)
+    DCP.save(state_dict, storage_writer=storage_writer)


 if __name__ == "__main__":
diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py
index 636d10a51..0a1649ea4 100644
--- a/scripts/generate/test_generate.py
+++ b/scripts/generate/test_generate.py
@@ -141,9 +141,9 @@ def test_generate(
     model.to_empty(device=device_type)
     model.eval()

-    state_dict = {"model": model.state_dict()}
+    state_dict = model.state_dict()
     for k in excluded_parameters_for_model_only:
-        state_dict["model"].pop(k, None)
+        state_dict.pop(k, None)

     # Checkpoint Loading
     begin = time.monotonic()
diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index 838500dc0..bff8f4c48 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader

-from torchtitan.components.checkpoint import CheckpointManager, MODEL
+from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.config_manager import Checkpoint as CheckpointConfig

@@ -105,11 +105,11 @@ def setUp(self):
         self.model_part = nn.Linear(2, 2)
         self.model_parts = [self.model_part]
+        self.states = {"trainer": torch.tensor([1.2347])}
         # TODO: Use a real OptimizerContainer here so that we can actually verify
         # some optimizer.state_dict() behavior (e.g., the key being the parameter name.)
         self.optimizers = FakeOptimizersContainer()
         self.lr_schedulers = FakeLRSchedulersContainer()
-        self.states = {}
         self.data_loader = FakeDataLoader()
         self.ft_manager = DummyFTManager()
@@ -161,7 +161,7 @@ def fake_load(self, states: dict, checkpoint_id=None):
             if key in states and hasattr(states[key], "load_state_dict"):
                 states[key].load_state_dict(val)
             elif key in states and isinstance(states[key], torch.Tensor):
-                states[key] = val
+                states[key].copy_(val)

     @mock.patch("torch.distributed.get_rank", return_value=0)
     @mock.patch("torchtitan.components.checkpoint.dcp.save")
@@ -354,7 +354,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states={MODEL: self.model_part},
+            states=self.states,
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
@@ -373,7 +373,7 @@
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states={MODEL: self.model_part},
+            states=self.states,
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
@@ -451,13 +451,12 @@ def test_ft_async_save_calls_async_wait(
         ft_manager.manager.return_value = mock.Mock()
         ft_manager.manager.participating_rank = mock.Mock(return_value=0)
         ft_manager.enabled = True
-        states = {"trainer": torch.tensor([0])}
         manager = CheckpointManager(
             dataloader=self.data_loader,
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states=states,
+            states=self.states,
             job_config=job_config,
             ft_manager=ft_manager,
         )
@@ -571,14 +570,13 @@ def __init__(self):
         self.assertEqual(mock_save.call_count, 1)
         checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
         saved_data = torch.load(checkpoint_path, weights_only=False)
-        model_state_dict = saved_data[MODEL]

         # Verify that freqs_cis is NOT in the saved state dict
-        self.assertNotIn("freqs_cis", model_state_dict)
+        self.assertNotIn("freqs_cis", saved_data)
         # Verify that other parameters ARE in the saved state dict
-        self.assertIn("weight", model_state_dict)
-        self.assertIn("bias", model_state_dict)
-        self.assertIn("other_param", model_state_dict)
+        self.assertIn("weight", saved_data)
+        self.assertIn("bias", saved_data)
+        self.assertIn("other_param", saved_data)

         manager.close()

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index fc0e1ab39..ff055cbe7 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -345,12 +345,15 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         # freed until _async_wait()
         if last_step:
             self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            return
+
+        states = self._flattened_model_states_sd()
+        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             if self.stager is None:
                 self.stager = DefaultStager(StagingOptions(True, True, True, True))
             result = dcp.async_save(
-                self.states,
+                states,
                 checkpoint_id=checkpoint_id,
                 process_group=self.pg,
                 async_checkpointer_type=AsyncCheckpointerType.PROCESS,
@@ -361,11 +364,11 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.save_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                states, checkpoint_id=checkpoint_id, process_group=self.pg
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(states, checkpoint_id=checkpoint_id)
         self._purge_stale_checkpoints()

         logger.info(
@@ -502,6 +505,19 @@ def _ft_load(self) -> None:
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
         )

+    def _flattened_model_states_sd(
+        self, state_dict: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        """Flatten the model states into a single dictionary.
+
+        Note that other states, such as optimizer states, are not flattened.
+        """
+        states = state_dict if state_dict is not None else self.states
+        sd = {k: v for k, v in states.items() if k != MODEL}
+        if MODEL in states:
+            sd.update(states[MODEL].state_dict())
+        return sd
+
     def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """Determines which states to load for the given step.
@@ -516,8 +532,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            return sd
+            return self.states[MODEL].state_dict()

         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
@@ -527,6 +542,8 @@
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }

+        states_to_load = self._flattened_model_states_sd(states_to_load)
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
@@ -539,25 +556,19 @@ def _save_last_step(self, curr_step: int) -> None:
         # current dtype is not the same as the export dtype at the end of the training.

         if self.last_save_model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #     'tok_embeddings.weight':...,
-            #     'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
+            states = self.states[MODEL].state_dict()

             if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
+                states = {k: v.to(self.export_dtype) for k, v in states.items()}
             logger.info(
                 f"Saving a model weights only checkpoint in {self.export_dtype} "
                 f"at last step, step {curr_step}."
             )
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
+            states = self._flattened_model_states_sd()

-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))

     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:

From 4e9cb4050f027565f07670a4c58b895ed886354b Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 27 Jun 2025 15:20:36 -0700
Subject: [PATCH 2/5] unittest

---
 tests/unit_tests/test_checkpoint.py | 40 +++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index bff8f4c48..bb3fec281 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -580,6 +580,46 @@ def __init__(self):

         manager.close()

+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_verify_prefix(self, mock_save, mock_load, mock_rank):
+        def fake_save(state_dict: dict, checkpoint_id: str):
+            self.assertIn("bias", state_dict)
+            self.assertIn("weight", state_dict)
+            # No model prefix
+            self.assertNotIn("model", state_dict)
+            if "step-1" in checkpoint_id:
+                self.assertIn("optimizer", state_dict)
+                self.fake_save(state_dict, checkpoint_id)
+            else:
+                self.assertNotIn("optimizer", state_dict)
+            return
+
+        def fake_load(state_dict: dict, checkpoint_id=None):
+            self.assertIn("bias", state_dict)
+            self.assertIn("weight", state_dict)
+            # No model prefix
+            self.assertNotIn("model", state_dict)
+            self.assertNotIn("optimizer", state_dict)
+
+        self.job_config.checkpoint.last_save_model_weights_only = True
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        mock_save.side_effect = fake_save
+        mock_load.side_effect = fake_load
+        manager.save(curr_step=1)
+        manager.save(curr_step=2, last_step=True)
+        manager.load(step=1)
+

 if __name__ == "__main__":
     unittest.main()

From 7bb339536aa47ed7124f2003965c30f5f121c50b Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 30 Jun 2025 23:37:46 -0700
Subject: [PATCH 3/5] lint

---
 torchtitan/components/checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index ff055cbe7..7d88bbf78 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -52,7 +52,7 @@ class AsyncMode(str, enum.Enum):
 # For now, we will manually pop the freqs_cis buffer, as we made this permanent
 # temporarily and we don't want to include it in the exported state_dict.
 # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}
+excluded_parameters_for_model_only = {}


 class ModelWrapper(Stateful):

From f0d0dd71d7991deab54a2d6b4d00846fdb8bdf50 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 1 Jul 2025 22:40:46 -0700
Subject: [PATCH 4/5] misc

---
 torchtitan/components/checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 7d88bbf78..ff055cbe7 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -52,7 +52,7 @@ class AsyncMode(str, enum.Enum):
 # For now, we will manually pop the freqs_cis buffer, as we made this permanent
 # temporarily and we don't want to include it in the exported state_dict.
 # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {}
+excluded_parameters_for_model_only = {"freqs_cis"}


 class ModelWrapper(Stateful):

From f1900870cfb1d4897916281bbd7e12eedced1915 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 2 Jul 2025 11:33:35 -0700
Subject: [PATCH 5/5] misc

---
 tests/unit_tests/test_checkpoint.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index bb3fec281..3317a51fe 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -601,9 +601,10 @@ def fake_load(state_dict: dict, checkpoint_id=None):
             self.assertIn("weight", state_dict)
             # No model prefix
             self.assertNotIn("model", state_dict)
-            self.assertNotIn("optimizer", state_dict)
+            self.assertIn("optimizer", state_dict)

         self.job_config.checkpoint.last_save_model_weights_only = True
+        self.job_config.checkpoint.initial_load_model_weights_only = False
         manager = CheckpointManager(
             dataloader=self.data_loader,
             model_parts=self.model_parts,
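For readers skimming the series, below is a minimal, runnable sketch of the flattening behavior that PATCH 1/5 adds as CheckpointManager._flattened_model_states_sd: the model's parameter names are promoted to the top level of the checkpoint state_dict, while non-model entries (optimizer, lr scheduler, dataloader) stay nested under their own keys. The FakeModelWrapper class and the example values are illustrative assumptions, not objects from the torchtitan codebase; only the flattening logic mirrors the patch.

# Illustrative sketch only: mirrors the flattening logic introduced in PATCH 1/5.
# "FakeModelWrapper" and the example keys/values are assumptions for demonstration.
from typing import Any

MODEL = "model"


class FakeModelWrapper:
    """Stand-in for the Stateful model wrapper; only state_dict() matters here."""

    def state_dict(self) -> dict[str, Any]:
        return {"tok_embeddings.weight": 1.0, "layers.0.attention.wq.weight": 2.0}


def flattened_model_states_sd(states: dict[str, Any]) -> dict[str, Any]:
    # Keep non-model entries (optimizer, lr_scheduler, dataloader, ...) nested
    # as-is, but splice the model's parameter names into the top level.
    sd = {k: v for k, v in states.items() if k != MODEL}
    if MODEL in states:
        sd.update(states[MODEL].state_dict())
    return sd


states = {MODEL: FakeModelWrapper(), "optimizer": {"step": 10}}
flat = flattened_model_states_sd(states)
# Before: {"model": <wrapper>, "optimizer": {...}}
# After:  {"tok_embeddings.weight": ..., "layers.0.attention.wq.weight": ...,
#          "optimizer": {...}}  -- no "model" prefix, as test_verify_prefix asserts.
assert "model" not in flat
assert "tok_embeddings.weight" in flat and "optimizer" in flat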