Commit 7d5f3cc

[WIP][RFC] Always flatten model state_dict (#1347)
The model state_dict is unique compared to other state dictionaries (e.g., the optimizer's): it is the only one that is exported outside of TorchTitan and imported from other sources. To ensure FQN consistency, we previously removed the prefix during the first checkpoint load and the last checkpoint save. However, this approach has caused confusion among users, despite the available options to control the behavior.

This PR resolves the issue by always flattening the model state_dict, eliminating the `"MODEL."` prefix from its keys. We decided not to flatten all components due to the risk of key collisions between different components. Instead, this PR flattens only the model state_dict, which is a special case. While this solution isn't perfect, since it introduces different handling for different components, it is a good compromise given the unique nature of the model state_dict.

Also see the discussion in #1321 (comment).

This is the pseudocode for the current state:

```
if model_only:
    state_dict = model.state_dict()
else:
    state_dict = {
        "MODEL": model,
        "OPTIMIZER": optimizer,
        ...
    }
```

This is the pseudocode after this PR is landed:

```
state_dict = model.state_dict()
if not model_only:
    state_dict.update(
        {"OPTIMIZER": optimizer, ...}
    )
```

FSDP4 vs. FSDP4 TP2 loss curves with a seed checkpoint and `--training.seed=42`:

[Screenshot: loss curves]
1 parent 4aa9fde commit 7d5f3cc
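To make the resulting key layout concrete, here is a small illustrative sketch. The parameter FQNs are examples taken from the comments this PR removes, and the lowercase component keys are assumed stand-ins for the real key constants:

```python
# Illustrative only: top-level keys of a full (not model-only) checkpoint.

# Before this PR: model weights are nested under a "model" entry.
before = {
    "model": {
        "tok_embeddings.weight": "...",
        "layers.0.attention.wq.weight": "...",
    },
    "optimizer": "...",
}

# After this PR: model weights are flattened to the top level, so a
# model-only export and a full checkpoint share the same model FQNs;
# non-model components keep their own nested top-level entries.
after = {
    "tok_embeddings.weight": "...",
    "layers.0.attention.wq.weight": "...",
    "optimizer": "...",
}
```

Keeping the non-model components nested is what avoids the key-collision risk mentioned above.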

File tree: 4 files changed (+81, -31 lines)


scripts/convert_llama_to_dcp.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -125,7 +125,7 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
     logger.info(f"Writing to DCP at '{output_dir}'")
     output_dir.mkdir(parents=True, exist_ok=True)
     storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
-    DCP.save({"model": state_dict}, storage_writer=storage_writer)
+    DCP.save(state_dict, storage_writer=storage_writer)
 
 
 if __name__ == "__main__":
```
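With the wrapper dict gone, the converted DCP checkpoint stores plain parameter FQNs, so it can be read straight back into a module's state_dict. Below is a minimal sketch under stated assumptions: the `nn.Linear` and the output path are placeholders, and a recent `torch.distributed.checkpoint` load API is assumed.

```python
import torch.distributed.checkpoint as DCP
import torch.nn as nn

model = nn.Linear(4, 4)  # placeholder for the real converted model
output_dir = "outputs/llama_dcp"  # placeholder path

# Keys are flat FQNs (e.g. "weight", "bias"), not nested under "model".
state_dict = model.state_dict()
DCP.load(state_dict, storage_reader=DCP.filesystem.FileSystemReader(output_dir))
model.load_state_dict(state_dict)
```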

scripts/generate/test_generate.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -141,9 +141,9 @@ def test_generate(
     model.to_empty(device=device_type)
     model.eval()
 
-    state_dict = {"model": model.state_dict()}
+    state_dict = model.state_dict()
     for k in excluded_parameters_for_model_only:
-        state_dict["model"].pop(k, None)
+        state_dict.pop(k, None)
 
     # Checkpoint Loading
     begin = time.monotonic()
```
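Since the state_dict is now flat, excluded parameters are dropped directly by their FQNs. A tiny sketch with illustrative keys (the real contents of `excluded_parameters_for_model_only` may differ):

```python
# Illustrative keys; "freqs_cis" mirrors the non-persistent buffer that is
# excluded elsewhere in this PR's tests.
excluded_parameters_for_model_only = {"freqs_cis"}

state_dict = {
    "freqs_cis": "...",
    "tok_embeddings.weight": "...",
}
for k in excluded_parameters_for_model_only:
    state_dict.pop(k, None)  # safe even if a key is absent

assert "freqs_cis" not in state_dict
```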

tests/unit_tests/test_checkpoint.py

Lines changed: 51 additions & 12 deletions
```diff
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torchtitan.components.checkpoint import CheckpointManager, MODEL
+from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.config_manager import Checkpoint as CheckpointConfig
 
 
@@ -105,11 +105,11 @@ def setUp(self):
 
         self.model_part = nn.Linear(2, 2)
         self.model_parts = [self.model_part]
+        self.states = {"trainer": torch.tensor([1.2347])}
         # TODO: Use a real OptimizerContainer here so that we can actually verify
         # some optimizer.state_dict() behavior (e.g., the key being the parameter name.)
         self.optimizers = FakeOptimizersContainer()
         self.lr_schedulers = FakeLRSchedulersContainer()
-        self.states = {}
         self.data_loader = FakeDataLoader()
         self.ft_manager = DummyFTManager()
 
@@ -161,7 +161,7 @@ def fake_load(self, states: dict, checkpoint_id=None):
             if key in states and hasattr(states[key], "load_state_dict"):
                 states[key].load_state_dict(val)
             elif key in states and isinstance(states[key], torch.Tensor):
-                states[key] = val
+                states[key].copy_(val)
 
     @mock.patch("torch.distributed.get_rank", return_value=0)
     @mock.patch("torchtitan.components.checkpoint.dcp.save")
@@ -354,7 +354,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states={MODEL: self.model_part},
+            states=self.states,
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
@@ -373,7 +373,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states={MODEL: self.model_part},
+            states=self.states,
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
@@ -451,13 +451,12 @@ def test_ft_async_save_calls_async_wait(
         ft_manager.manager.return_value = mock.Mock()
         ft_manager.manager.participating_rank = mock.Mock(return_value=0)
         ft_manager.enabled = True
-        states = {"trainer": torch.tensor([0])}
         manager = CheckpointManager(
             dataloader=self.data_loader,
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states=states,
+            states=self.states,
             job_config=job_config,
             ft_manager=ft_manager,
         )
@@ -571,17 +570,57 @@ def __init__(self):
         self.assertEqual(mock_save.call_count, 1)
         checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
         saved_data = torch.load(checkpoint_path, weights_only=False)
-        model_state_dict = saved_data[MODEL]
 
         # Verify that freqs_cis is NOT in the saved state dict
-        self.assertNotIn("freqs_cis", model_state_dict)
+        self.assertNotIn("freqs_cis", saved_data)
         # Verify that other parameters ARE in the saved state dict
-        self.assertIn("weight", model_state_dict)
-        self.assertIn("bias", model_state_dict)
-        self.assertIn("other_param", model_state_dict)
+        self.assertIn("weight", saved_data)
+        self.assertIn("bias", saved_data)
+        self.assertIn("other_param", saved_data)
 
         manager.close()
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_verify_prefix(self, mock_save, mock_load, mock_rank):
+        def fake_save(state_dict: dict, checkpoint_id: str):
+            self.assertIn("bias", state_dict)
+            self.assertIn("weight", state_dict)
+            # No model prefix
+            self.assertNotIn("model", state_dict)
+            if "step-1" in checkpoint_id:
+                self.assertIn("optimizer", state_dict)
+                self.fake_save(state_dict, checkpoint_id)
+            else:
+                self.assertNotIn("optimizer", state_dict)
+                return
+
+        def fake_load(state_dict: dict, checkpoint_id=None):
+            self.assertIn("bias", state_dict)
+            self.assertIn("weight", state_dict)
+            # No model prefix
+            self.assertNotIn("model", state_dict)
+            self.assertIn("optimizer", state_dict)
+
+        self.job_config.checkpoint.last_save_model_weights_only = True
+        self.job_config.checkpoint.initial_load_model_weights_only = False
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        mock_save.side_effect = fake_save
+        mock_load.side_effect = fake_load
+        manager.save(curr_step=1)
+        manager.save(curr_step=2, last_step=True)
+        manager.load(step=1)
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchtitan/components/checkpoint.py

Lines changed: 27 additions & 16 deletions
```diff
@@ -345,12 +345,15 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         # freed until _async_wait()
         if last_step:
             self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            return
+
+        states = self._flattened_model_states_sd()
+        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             if self.stager is None:
                 self.stager = DefaultStager(StagingOptions(True, True, True, True))
             result = dcp.async_save(
-                self.states,
+                states,
                 checkpoint_id=checkpoint_id,
                 process_group=self.pg,
                 async_checkpointer_type=AsyncCheckpointerType.PROCESS,
@@ -361,11 +364,11 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.save_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                states, checkpoint_id=checkpoint_id, process_group=self.pg
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(states, checkpoint_id=checkpoint_id)
         self._purge_stale_checkpoints()
 
         logger.info(
@@ -502,6 +505,19 @@ def _ft_load(self) -> None:
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def _flattened_model_states_sd(
+        self, state_dict: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        """Flatten the model states into a single dictionary.
+
+        Note that other states, such as optimizer states, are not flattened.
+        """
+        states = state_dict if state_dict is not None else self.states
+        sd = {k: v for k, v in states.items() if k != MODEL}
+        if MODEL in states:
+            sd.update(states[MODEL].state_dict())
+        return sd
+
     def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -516,8 +532,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            return sd
+            return self.states[MODEL].state_dict()
 
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
@@ -527,6 +542,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }
 
+        states_to_load = self._flattened_model_states_sd(states_to_load)
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
 
@@ -539,25 +556,19 @@ def _save_last_step(self, curr_step: int) -> None:
         # current dtype is not the same as the export dtype at the end of the training.
 
         if self.last_save_model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #     'tok_embeddings.weight': ...,
-            #     'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
+            states = self.states[MODEL].state_dict()
 
             if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
+                states = {k: v.to(self.export_dtype) for k, v in states.items()}
             logger.info(
                 f"Saving a model weights only checkpoint in {self.export_dtype} "
                 f"at last step, step {curr_step}."
             )
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
+            states = self._flattened_model_states_sd()
 
-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
```
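For reference, a self-contained sketch that mirrors the new `_flattened_model_states_sd` helper on toy objects; the lowercase `MODEL` value and the placeholder components are assumptions:

```python
from typing import Any

import torch.nn as nn

MODEL = "model"  # assumed value of the MODEL key constant


def flattened_model_states_sd(states: dict[str, Any]) -> dict[str, Any]:
    """Flatten only the model entry; other components keep their own keys."""
    sd = {k: v for k, v in states.items() if k != MODEL}
    if MODEL in states:
        sd.update(states[MODEL].state_dict())
    return sd


model = nn.Linear(2, 2)  # stands in for the wrapped model parts
states = {MODEL: model, "optimizer": object(), "lr_scheduler": object()}

flat = flattened_model_states_sd(states)
assert "weight" in flat and "bias" in flat  # model FQNs at the top level
assert "optimizer" in flat and MODEL not in flat  # other components untouched
```

This mirrors how the helper is used in both `save()` and `_states_to_load()` above.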
