[WIP][RFC] Always flatten model state_dict #1347


Merged

merged 5 commits on Jul 3, 2025
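In short, the PR stops wrapping the model state_dict under a "model" key when checkpoints are saved and loaded. A minimal sketch of the key layout before and after this change (parameter names are illustrative, taken from the comment removed in checkpoint.py below; this is not the torchtitan implementation):

```python
# Illustrative sketch only.
# Before: the model state_dict is nested, so the checkpoint carries a "model." prefix.
before = {
    "model": {
        "tok_embeddings.weight": ...,
        "layers.0.attention.wq.weight": ...,
    },
    "optimizer": ...,
}

# After: the model state_dict is flattened into the top level; non-model states
# such as the optimizer keep their own keys, and there is no "model." prefix.
after = {
    "tok_embeddings.weight": ...,
    "layers.0.attention.wq.weight": ...,
    "optimizer": ...,
}
```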
2 changes: 1 addition & 1 deletion scripts/convert_llama_to_dcp.py
@@ -125,7 +125,7 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
logger.info(f"Writing to DCP at '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
DCP.save({"model": state_dict}, storage_writer=storage_writer)
DCP.save(state_dict, storage_writer=storage_writer)


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions scripts/generate/test_generate.py
@@ -141,9 +141,9 @@ def test_generate(
model.to_empty(device=device_type)
model.eval()

state_dict = {"model": model.state_dict()}
state_dict = model.state_dict()
for k in excluded_parameters_for_model_only:
state_dict["model"].pop(k, None)
state_dict.pop(k, None)

# Checkpoint Loading
begin = time.monotonic()
63 changes: 51 additions & 12 deletions tests/unit_tests/test_checkpoint.py
@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtitan.components.checkpoint import CheckpointManager, MODEL
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.config_manager import Checkpoint as CheckpointConfig


@@ -105,11 +105,11 @@ def setUp(self):

self.model_part = nn.Linear(2, 2)
self.model_parts = [self.model_part]
self.states = {"trainer": torch.tensor([1.2347])}
# TODO: Use a real OptimizerContainer here so that we can actually verify
# some optimizer.state_dict() behavior (e.g., the key being the parameter name.)
self.optimizers = FakeOptimizersContainer()
self.lr_schedulers = FakeLRSchedulersContainer()
self.states = {}
self.data_loader = FakeDataLoader()
self.ft_manager = DummyFTManager()

@@ -161,7 +161,7 @@ def fake_load(self, states: dict, checkpoint_id=None):
if key in states and hasattr(states[key], "load_state_dict"):
states[key].load_state_dict(val)
elif key in states and isinstance(states[key], torch.Tensor):
states[key] = val
states[key].copy_(val)

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.save")
@@ -354,7 +354,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states={MODEL: self.model_part},
states=self.states,
job_config=self.job_config,
ft_manager=self.ft_manager,
)
@@ -373,7 +373,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states={MODEL: self.model_part},
states=self.states,
job_config=self.job_config,
ft_manager=self.ft_manager,
)
@@ -451,13 +451,12 @@ def test_ft_async_save_calls_async_wait(
ft_manager.manager.return_value = mock.Mock()
ft_manager.manager.participating_rank = mock.Mock(return_value=0)
ft_manager.enabled = True
states = {"trainer": torch.tensor([0])}
manager = CheckpointManager(
dataloader=self.data_loader,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=states,
states=self.states,
job_config=job_config,
ft_manager=ft_manager,
)
@@ -571,17 +570,57 @@ def __init__(self):
self.assertEqual(mock_save.call_count, 1)
checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
saved_data = torch.load(checkpoint_path, weights_only=False)
model_state_dict = saved_data[MODEL]

# Verify that freqs_cis is NOT in the saved state dict
self.assertNotIn("freqs_cis", model_state_dict)
self.assertNotIn("freqs_cis", saved_data)
# Verify that other parameters ARE in the saved state dict
self.assertIn("weight", model_state_dict)
self.assertIn("bias", model_state_dict)
self.assertIn("other_param", model_state_dict)
self.assertIn("weight", saved_data)
self.assertIn("bias", saved_data)
self.assertIn("other_param", saved_data)

manager.close()

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.load")
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
def fake_save(state_dict: dict, checkpoint_id: str):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
self.assertNotIn("model", state_dict)
if "step-1" in checkpoint_id:
self.assertIn("optimizer", state_dict)
self.fake_save(state_dict, checkpoint_id)
else:
self.assertNotIn("optimizer", state_dict)
return

def fake_load(state_dict: dict, checkpoint_id=None):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
self.assertNotIn("model", state_dict)
self.assertIn("optimizer", state_dict)

self.job_config.checkpoint.last_save_model_weights_only = True
self.job_config.checkpoint.initial_load_model_weights_only = False
manager = CheckpointManager(
dataloader=self.data_loader,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=self.states,
job_config=self.job_config,
ft_manager=self.ft_manager,
)

mock_save.side_effect = fake_save
mock_load.side_effect = fake_load
manager.save(curr_step=1)
manager.save(curr_step=2, last_step=True)
manager.load(step=1)


if __name__ == "__main__":
unittest.main()
43 changes: 27 additions & 16 deletions torchtitan/components/checkpoint.py
@@ -345,12 +345,15 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
# freed until _async_wait()
if last_step:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
return

states = self._flattened_model_states_sd()
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
if self.stager is None:
self.stager = DefaultStager(StagingOptions(True, True, True, True))
result = dcp.async_save(
self.states,
states,
checkpoint_id=checkpoint_id,
process_group=self.pg,
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
@@ -361,11 +364,11 @@ elif self.async_mode == AsyncMode.ASYNC:
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.save_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
save_with_gc(states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()

logger.info(
@@ -502,6 +505,19 @@ def _ft_load(self) -> None:
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)

def _flattened_model_states_sd(
self, state_dict: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Flatten the model states into a single dictionary.

Note that other states, such as optimizer states, are not flattened.
"""
states = state_dict if state_dict is not None else self.states
sd = {k: v for k, v in states.items() if k != MODEL}
if MODEL in states:
sd.update(states[MODEL].state_dict())
return sd

def _states_to_load(self, model_only: bool) -> dict[str, Any]:
"""Determines which states to load for the given step.

@@ -516,8 +532,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
"""
# For the first step, we will only load the model weights.
if model_only:
sd = self.states[MODEL].state_dict()
return sd
return self.states[MODEL].state_dict()
Contributor:

Does this mean "model" still exists in state_dict as a key -- we only flatten it in the checkpoint (and its load and save)?

Contributor Author:

In Checkpointer, we still keep separate keys in self.states, like MODEL and OPTIMIZER. This allows us to manipulate the different state_dicts. This line uses MODEL to access only the model state_dict, but it does not wrap it, so there will be no "model." prefix.
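A minimal, runnable sketch of the behavior described above, assuming a plain nn.Module stands in for the object stored under the MODEL key (illustrative only, not the torchtitan code):

```python
import torch.nn as nn

# Stand-in for self.states: the model is still kept under its own key internally,
# alongside other stateful components such as the optimizer.
model = nn.Linear(2, 2)
states = {"model": model, "optimizer": {"step": 0}}

# Flattening for save/load: lift the model's state_dict to the top level and
# leave every non-model entry untouched, so no "model." prefix is produced.
flat = {k: v for k, v in states.items() if k != "model"}
flat.update(states["model"].state_dict())

print(sorted(states.keys()))  # ['model', 'optimizer']
print(sorted(flat.keys()))    # ['bias', 'optimizer', 'weight']
```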


for exclude_key in self.exclude_from_loading:
if exclude_key not in self.states:
@@ -527,6 +542,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
}

states_to_load = self._flattened_model_states_sd(states_to_load)

if self.ft_manager:
states_to_load.pop(DATALOADER)

@@ -539,25 +556,19 @@ def _save_last_step(self, curr_step: int) -> None:
# current dtype is not the same as the export dtype at the end of the training.

if self.last_save_model_weights_only:
# We update self.states to keep the model only.
# After this update, self.states = {
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.
self.states = self.states[MODEL].state_dict()
states = self.states[MODEL].state_dict()

if self.export_dtype != torch.float32:
self.states = {
k: v.to(self.export_dtype) for k, v in self.states.items()
}
states = {k: v.to(self.export_dtype) for k, v in states.items()}
logger.info(
f"Saving a model weights only checkpoint in {self.export_dtype} "
f"at last step, step {curr_step}."
)
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
states = self._flattened_model_states_sd()

save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable_checkpoint: