Skip to content

Commit 4f05370

Browse files
committed
[WIP][RFC] Always flatten model state_dict
1 parent aefe15a commit 4f05370

File tree

4 files changed

+35
-19
lines changed

4 files changed

+35
-19
lines changed

scripts/convert_llama_to_dcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
125125
logger.info(f"Writing to DCP at '{output_dir}'")
126126
output_dir.mkdir(parents=True, exist_ok=True)
127127
storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
128-
DCP.save({"model": state_dict}, storage_writer=storage_writer)
128+
DCP.save(state_dict, storage_writer=storage_writer)
129129

130130

131131
if __name__ == "__main__":

scripts/generate/test_generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def test_generate(
141141
model.to_empty(device=device_type)
142142
model.eval()
143143

144-
state_dict = {"model": model.state_dict()}
144+
state_dict = model.state_dict()
145145
for k in excluded_parameters_for_model_only:
146-
state_dict["model"].pop(k, None)
146+
state_dict.pop(k, None)
147147

148148
# Checkpoint Loading
149149
begin = time.monotonic()

tests/unit_tests/test_checkpoint.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
import torch.nn as nn
1717
from torch.utils.data import DataLoader
18-
from torchtitan.components.checkpoint import CheckpointManager, MODEL
18+
from torchtitan.components.checkpoint import CheckpointManager
1919
from torchtitan.config_manager import Checkpoint as CheckpointConfig
2020

2121

@@ -105,11 +105,11 @@ def setUp(self):
105105

106106
self.model_part = nn.Linear(2, 2)
107107
self.model_parts = [self.model_part]
108+
self.states = {"trainer": torch.tensor([1.2347])}
108109
# TODO: Use a real OptimizerContainer here so that we can actually verify
109110
# some optimizer.state_dict() behavior (e.g., the key being the parameter name.)
110111
self.optimizers = FakeOptimizersContainer()
111112
self.lr_schedulers = FakeLRSchedulersContainer()
112-
self.states = {}
113113
self.data_loader = FakeDataLoader()
114114
self.ft_manager = DummyFTManager()
115115

@@ -161,7 +161,7 @@ def fake_load(self, states: dict, checkpoint_id=None):
161161
if key in states and hasattr(states[key], "load_state_dict"):
162162
states[key].load_state_dict(val)
163163
elif key in states and isinstance(states[key], torch.Tensor):
164-
states[key] = val
164+
states[key].copy_(val)
165165

166166
@mock.patch("torch.distributed.get_rank", return_value=0)
167167
@mock.patch("torchtitan.components.checkpoint.dcp.save")
@@ -354,7 +354,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
354354
model_parts=self.model_parts,
355355
optimizers=self.optimizers,
356356
lr_schedulers=self.lr_schedulers,
357-
states={MODEL: self.model_part},
357+
states=self.states,
358358
job_config=self.job_config,
359359
ft_manager=self.ft_manager,
360360
)
@@ -373,7 +373,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
373373
model_parts=self.model_parts,
374374
optimizers=self.optimizers,
375375
lr_schedulers=self.lr_schedulers,
376-
states={MODEL: self.model_part},
376+
states=self.states,
377377
job_config=self.job_config,
378378
ft_manager=self.ft_manager,
379379
)
@@ -449,13 +449,12 @@ def test_ft_async_save_calls_async_wait(
449449
job_config.checkpoint.async_mode = "disabled"
450450
ft_manager = mock.Mock()
451451
ft_manager.enabled = True
452-
states = {"trainer": torch.tensor([0])}
453452
manager = CheckpointManager(
454453
dataloader=self.data_loader,
455454
model_parts=self.model_parts,
456455
optimizers=self.optimizers,
457456
lr_schedulers=self.lr_schedulers,
458-
states=states,
457+
states=self.states,
459458
job_config=job_config,
460459
ft_manager=ft_manager,
461460
)
@@ -569,14 +568,13 @@ def __init__(self):
569568
self.assertEqual(mock_save.call_count, 1)
570569
checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
571570
saved_data = torch.load(checkpoint_path, weights_only=False)
572-
model_state_dict = saved_data[MODEL]
573571

574572
# Verify that freqs_cis is NOT in the saved state dict
575-
self.assertNotIn("freqs_cis", model_state_dict)
573+
self.assertNotIn("freqs_cis", saved_data)
576574
# Verify that other parameters ARE in the saved state dict
577-
self.assertIn("weight", model_state_dict)
578-
self.assertIn("bias", model_state_dict)
579-
self.assertIn("other_param", model_state_dict)
575+
self.assertIn("weight", saved_data)
576+
self.assertIn("bias", saved_data)
577+
self.assertIn("other_param", saved_data)
580578

581579
manager.close()
582580

torchtitan/components/checkpoint.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,11 +392,15 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
392392
elif self.async_mode == AsyncMode.ASYNC:
393393
GarbageCollection.collect("GC collection invoked by checkpointer.")
394394
self.async_future = dcp.async_save(
395-
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
395+
self._flattend_model_states_sd(),
396+
checkpoint_id=checkpoint_id,
397+
process_group=self.pg,
396398
)
397399
GarbageCollection.collect("GC collection invoked by checkpointer.")
398400
else:
399-
save_with_gc(self.states, checkpoint_id=checkpoint_id)
401+
save_with_gc(
402+
self._flattend_model_states_sd(), checkpoint_id=checkpoint_id
403+
)
400404
self._purge_stale_checkpoints()
401405

402406
logger.info(
@@ -559,6 +563,19 @@ def _ft_load(self) -> None:
559563
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
560564
)
561565

566+
def _flattend_model_states_sd(
567+
self, state_dict: dict[str, Any] | None = None
568+
) -> dict[str, Any]:
569+
"""Flatten the model states into a single dictionary.
570+
571+
Note that other states, such as optimizer states, are not flattened.
572+
"""
573+
states = state_dict if state_dict is not None else self.states
574+
sd = {k: v for k, v in states.items() if k != MODEL}
575+
if MODEL in states:
576+
sd.update(states[MODEL].state_dict())
577+
return sd
578+
562579
def _states_to_load(self, model_only: bool) -> dict[str, Any]:
563580
"""Determines which states to load for the given step.
564581
@@ -573,8 +590,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
573590
"""
574591
# For the first step, we will only load the model weights.
575592
if model_only:
576-
sd = self.states[MODEL].state_dict()
577-
return sd
593+
return self.states[MODEL].state_dict()
578594

579595
for exclude_key in self.exclude_from_loading:
580596
if exclude_key not in self.states:
@@ -584,6 +600,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
584600
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
585601
}
586602

603+
states_to_load = self._flattend_model_states_sd(states_to_load)
604+
587605
if self.ft_manager:
588606
states_to_load.pop(DATALOADER)
589607

0 commit comments

Comments
 (0)