
Commit c3dc50a

[WIP][RFC] Always flatten model state_dict
1 parent: c08c9d4

File tree: 4 files changed, +40 −31 lines

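For orientation: "flattening" means hoisting the model's parameter tensors out of the nested "model" entry so they become top-level keys of the checkpoint state dict, alongside the other (still-nested) states. A minimal sketch of the before/after layout, with key names borrowed from the comment removed in checkpoint.py below:

# Before this commit: model weights nested under a "model" key.
nested = {
    "model": {
        "tok_embeddings.weight": ...,
        "layers.0.attention.wq.weight": ...,
    },
    "optimizer": ...,
}

# After: parameter names are top-level keys; non-model states keep theirs.
flat = {
    "tok_embeddings.weight": ...,
    "layers.0.attention.wq.weight": ...,
    "optimizer": ...,
}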

scripts/convert_llama_to_dcp.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
     logger.info(f"Writing to DCP at '{output_dir}'")
     output_dir.mkdir(parents=True, exist_ok=True)
     storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
-    DCP.save({"model": state_dict}, storage_writer=storage_writer)
+    DCP.save(state_dict, storage_writer=storage_writer)


 if __name__ == "__main__":
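A minimal single-process sketch of the new save layout, assuming a recent PyTorch where dcp.save works without an initialized process group; the tensor and output path are illustrative:

import torch
import torch.distributed.checkpoint as DCP

# Hypothetical flat state dict, shaped like the converter's output above.
state_dict = {"tok_embeddings.weight": torch.zeros(8, 8)}

# Save the flat dict directly -- no {"model": ...} wrapper.
storage_writer = DCP.filesystem.FileSystemWriter("/tmp/llama_dcp", thread_count=8)
DCP.save(state_dict, storage_writer=storage_writer)

# Loading expects the same flat keys; dcp.load fills the tensors in place.
loaded = {"tok_embeddings.weight": torch.empty(8, 8)}
DCP.load(loaded, checkpoint_id="/tmp/llama_dcp")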

scripts/generate/test_generate.py

Lines changed: 2 additions & 2 deletions
@@ -141,9 +141,9 @@ def test_generate(
     model.to_empty(device=device_type)
     model.eval()

-    state_dict = {"model": model.state_dict()}
+    state_dict = model.state_dict()
     for k in excluded_parameters_for_model_only:
-        state_dict["model"].pop(k, None)
+        state_dict.pop(k, None)

     # Checkpoint Loading
     begin = time.monotonic()
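Since the dict is now flat, the exclusion loop pops parameter names directly; a tiny sketch, using freqs_cis as a stand-in for the entries of excluded_parameters_for_model_only:

import torch

state_dict = {"freqs_cis": torch.ones(2), "weight": torch.ones(2)}
for k in ["freqs_cis"]:  # stand-in for excluded_parameters_for_model_only
    state_dict.pop(k, None)  # pop with default: no KeyError if already absent
assert list(state_dict) == ["weight"]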

tests/unit_tests/test_checkpoint.py

Lines changed: 10 additions & 12 deletions
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torchtitan.components.checkpoint import CheckpointManager, MODEL
+from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.config_manager import Checkpoint as CheckpointConfig

@@ -105,11 +105,11 @@ def setUp(self):
         self.model_part = nn.Linear(2, 2)
         self.model_parts = [self.model_part]
+        self.states = {"trainer": torch.tensor([1.2347])}
         # TODO: Use a real OptimizerContainer here so that we can actually verify
         # some optimizer.state_dict() behavior (e.g., the key being the parameter name.)
         self.optimizers = FakeOptimizersContainer()
         self.lr_schedulers = FakeLRSchedulersContainer()
-        self.states = {}
         self.data_loader = FakeDataLoader()
         self.ft_manager = DummyFTManager()

@@ -161,7 +161,7 @@ def fake_load(self, states: dict, checkpoint_id=None):
             if key in states and hasattr(states[key], "load_state_dict"):
                 states[key].load_state_dict(val)
             elif key in states and isinstance(states[key], torch.Tensor):
-                states[key] = val
+                states[key].copy_(val)

     @mock.patch("torch.distributed.get_rank", return_value=0)
     @mock.patch("torchtitan.components.checkpoint.dcp.save")
@@ -354,7 +354,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states={MODEL: self.model_part},
+            states=self.states,
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
@@ -373,7 +373,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states={MODEL: self.model_part},
+            states=self.states,
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
@@ -451,13 +451,12 @@ def test_ft_async_save_calls_async_wait(
         ft_manager.manager.return_value = mock.Mock()
         ft_manager.manager.participating_rank = mock.Mock(return_value=0)
         ft_manager.enabled = True
-        states = {"trainer": torch.tensor([0])}
         manager = CheckpointManager(
             dataloader=self.data_loader,
             model_parts=self.model_parts,
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
-            states=states,
+            states=self.states,
             job_config=job_config,
             ft_manager=ft_manager,
         )
@@ -571,14 +570,13 @@ def __init__(self):
         self.assertEqual(mock_save.call_count, 1)
         checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
         saved_data = torch.load(checkpoint_path, weights_only=False)
-        model_state_dict = saved_data[MODEL]

         # Verify that freqs_cis is NOT in the saved state dict
-        self.assertNotIn("freqs_cis", model_state_dict)
+        self.assertNotIn("freqs_cis", saved_data)
         # Verify that other parameters ARE in the saved state dict
-        self.assertIn("weight", model_state_dict)
-        self.assertIn("bias", model_state_dict)
-        self.assertIn("other_param", model_state_dict)
+        self.assertIn("weight", saved_data)
+        self.assertIn("bias", saved_data)
+        self.assertIn("other_param", saved_data)

         manager.close()

torchtitan/components/checkpoint.py

Lines changed: 27 additions & 16 deletions
@@ -345,12 +345,15 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         # freed until _async_wait()
         if last_step:
             self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            return
+
+        states = self._flattened_model_states_sd()
+        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             if self.stager is None:
                 self.stager = DefaultStager(StagingOptions(True, True, True, True))
             result = dcp.async_save(
-                self.states,
+                states,
                 checkpoint_id=checkpoint_id,
                 process_group=self.pg,
                 async_checkpointer_type=AsyncCheckpointerType.PROCESS,
@@ -361,11 +364,11 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.save_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                states, checkpoint_id=checkpoint_id, process_group=self.pg
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(states, checkpoint_id=checkpoint_id)
         self._purge_stale_checkpoints()

         logger.info(
@@ -502,6 +505,19 @@ def _ft_load(self) -> None:
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
         )

+    def _flattened_model_states_sd(
+        self, state_dict: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        """Flatten the model states into a single dictionary.
+
+        Note that other states, such as optimizer states, are not flattened.
+        """
+        states = state_dict if state_dict is not None else self.states
+        sd = {k: v for k, v in states.items() if k != MODEL}
+        if MODEL in states:
+            sd.update(states[MODEL].state_dict())
+        return sd
+
     def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """Determines which states to load for the given step.
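To see the helper's effect, here is a minimal sketch assuming MODEL == "model" and a Stateful-style wrapper whose state_dict() yields parameter tensors (the wrapper class below is a stand-in, not torchtitan code):

import torch
import torch.nn as nn

class ModelWrapper:  # stand-in for the object stored under the MODEL key
    def __init__(self, module: nn.Module):
        self.module = module

    def state_dict(self):
        return self.module.state_dict()

states = {"model": ModelWrapper(nn.Linear(2, 2)), "train_state": torch.tensor([3])}

# Same logic as _flattened_model_states_sd: copy the non-model entries,
# then splice the model's parameter names in at the top level.
sd = {k: v for k, v in states.items() if k != "model"}
sd.update(states["model"].state_dict())
assert set(sd) == {"train_state", "weight", "bias"}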
@@ -516,8 +532,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            return sd
+            return self.states[MODEL].state_dict()

         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
@@ -527,6 +542,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }

+        states_to_load = self._flattened_model_states_sd(states_to_load)
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
@@ -539,25 +556,19 @@ def _save_last_step(self, curr_step: int) -> None:
         # current dtype is not the same as the export dtype at the end of the training.

         if self.last_save_model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #     'tok_embeddings.weight': ...,
-            #     'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
+            states = self.states[MODEL].state_dict()

             if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
+                states = {k: v.to(self.export_dtype) for k, v in states.items()}
             logger.info(
                 f"Saving a model weights only checkpoint in {self.export_dtype} "
                 f"at last step, step {curr_step}."
             )
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
+            states = self._flattened_model_states_sd()

-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))

     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
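Note the weights-only path now casts a local states variable instead of overwriting self.states, which leaves the manager's own state intact after the last-step save. A small sketch of the dtype cast, with an illustrative key and export dtype:

import torch

states = {"weight": torch.randn(2, 2)}  # float32 by default
export_dtype = torch.bfloat16

# Mirror of the diff's comprehension: cast every tensor to the export dtype.
if export_dtype != torch.float32:
    states = {k: v.to(export_dtype) for k, v in states.items()}
assert states["weight"].dtype == torch.bfloat16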
