Commit cddd7dc

Implement initial_load_path for checkpointer (#1236)
Currently, users have to either copy pre-trained checkpoints into a specific folder or point the output checkpoint folder at the pre-trained checkpoint folder, which is inconvenient and limiting. This pull request addresses the issue by adding support for loading the initial checkpoint from a user-specified path.
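
As an illustration of the new workflow, here is a minimal config sketch using the options introduced in this commit; the option names come from the diff below, while the concrete folder and path values are hypothetical:

```
[checkpoint]
enable_checkpoint = true
folder = "checkpoint"   # new checkpoints are written under {--job.dump_folder}/checkpoint
interval = 500
# Full path to an existing checkpoint folder, including the step directory if there is one.
# (Hypothetical path, for illustration only.)
initial_load_path = "/pre_train/checkpoints/llama3/llama3_8b/step_10000"
# Defaults to true: only the model weights are read from initial_load_path.
initial_load_model_weights_only = true
```

The path is only consulted when the run's own checkpoint folder does not exist yet; once checkpoints have been saved there, later restarts resume from that folder and `initial_load_path` is ignored.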
1 parent e320a37 commit cddd7dc

18 files changed: +116 −52 lines

docs/checkpoint.md

Lines changed: 4 additions & 4 deletions

@@ -25,19 +25,19 @@ interval = 500
 
 2. SAVE ONLY MODEL WEIGHTS
-By setting `model_weights_only` to `True`, the checkpoint will only contain the model weights and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
+By setting `last_save_model_weights_only` to `True`, the checkpoint will only contain the model weights and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
 ```
 [checkpoint]
 enable_checkpoint = true
-model_weights_only = true
+last_save_model_weights_only = true
 ```
 
 3. CHOOSE DESIRED EXPORT PRECISION
 The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`.
 ```
 [checkpoint]
 enable_checkpoint = true
-model_weights_only = true
+last_save_model_weights_only = true
 export_dtype = "bfloat16"
 ```
 
@@ -48,7 +48,7 @@ enable_checkpoint = true
 folder = "checkpoint"
 interval = 10
 load_step = 5
-model_weights_only = true
+last_save_model_weights_only = true
 export_dtype = "bfloat16"
 ```

tests/integration_tests.py

Lines changed: 4 additions & 4 deletions

@@ -122,22 +122,22 @@ def build_test_list():
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    "--checkpoint.model_weights_only",
+                    "--checkpoint.last_save_model_weights_only",
                 ],
             ],
             "Checkpoint Integration Test - Save Model Weights Only fp32",
-            "model_weights_only_fp32",
+            "last_save_model_weights_only_fp32",
         ),
         OverrideDefinitions(
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    "--checkpoint.model_weights_only",
+                    "--checkpoint.last_save_model_weights_only",
                     "--checkpoint.export_dtype bfloat16",
                 ],
             ],
             "Checkpoint Integration Test - Save Model Weights Only bf16",
-            "model_weights_only_bf16",
+            "last_save_model_weights_only_bf16",
         ),
         OverrideDefinitions(
             [

tests/unit_tests/test_checkpoint.py

Lines changed: 7 additions & 1 deletion

@@ -44,16 +44,22 @@ def fake_get_model_state_dict(model, *args, **kwargs):
     return model.state_dict()
 
 
+# TODO: The unittest is not well structured and does not cover enough paths.
+# It should be refactored.
+
+
 @dataclass
 class DummyCheckpointConfig:
     enable_checkpoint: bool = True
     folder: str = "dummy_folder"
     interval: int = 10
     async_mode: str = "disabled"
     keep_latest_k: int = 0
-    model_weights_only: bool = False
+    last_save_model_weights_only: bool = False
     export_dtype: str = "float32"
     exclude_from_loading = []
+    initial_load_model_weights_only: bool = False
+    initial_load_path: str = ""
 
 
 @dataclass

torchtitan/components/checkpoint.py

Lines changed: 56 additions & 23 deletions

@@ -81,6 +81,12 @@ class SaveDone:
     pass
 
 
+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}
+
+
 @torch.no_grad()
 def save_with_gc(state, checkpoint_id):
     dcp.save(state, checkpoint_id=checkpoint_id)

@@ -267,6 +273,10 @@ def load_state_dict(state_dict):
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
+        self.initial_load_path = ckpt_config.initial_load_path
+        self.initial_load_model_weights_only = (
+            ckpt_config.initial_load_model_weights_only
+        )
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
         if async_mode == AsyncMode.ASYNC or self.ft_manager:

@@ -287,7 +297,7 @@ def load_state_dict(state_dict):
         else:
             self.purge_thread = None
 
-        self.model_weights_only = ckpt_config.model_weights_only
+        self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading

@@ -408,21 +418,38 @@ def load(self, step: int = -1) -> bool:
         if self.ft_manager:
             self._ft_load()
 
-        if not self.enable_checkpoint or not os.path.isdir(self.folder):
+        if not self.enable_checkpoint:
             return False
 
-        if step == -1:
-            step = self._find_load_step()
+        model_only = False
+        if not os.path.exists(self.folder):
+            if self.initial_load_path:
+                checkpoint_id = self.initial_load_path
+                if not os.path.isdir(checkpoint_id):
+                    raise ValueError(
+                        "initial_load_full_checkpoint is specified but the path is not valid."
+                    )
+                model_only = self.initial_load_model_weights_only
+            else:
+                return False
+        else:
+            if self.initial_load_path:
+                logger.info(
+                    "`initial_load_path` is provided but the checkpoint folder exists. "
+                    "Checkpointer will use the checkpoints from the checkpoint folder."
+                )
+            step = self._find_load_step() if step == -1 else step
             if step == -1:
                 return False
+            model_only = step == 0
+            checkpoint_id = self._create_checkpoint_id(step)
 
-        checkpoint_id = self._create_checkpoint_id(step)
-        if not os.path.isdir(checkpoint_id):
-            return False
+            if not os.path.isdir(checkpoint_id):
+                return False
 
-        logger.info(f"Loading the checkpoint at step {step}.")
+        logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
-        states = self._states_to_load(step)
+        states = self._states_to_load(model_only)
         dcp.load(states, checkpoint_id=checkpoint_id)
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(

@@ -521,28 +548,36 @@ def _ft_load(self) -> None:
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
         )
 
-    def _states_to_load(self, step: int) -> dict[str, Any]:
+    def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """Determines which states to load for the given step.
 
-        When checkpointer determines which step of the checkpoint to load, this API is
-        used to determine which states to load based on the step.
+        This API is used to determine which states to load based on the
+        configurations.
 
         Args:
-            step (int): The step to load the checkpoint for.
+            model_only (bool): Whether to load the model only.
 
         Returns:
             Dict[str, Any]: The states to load for the given step.
         """
         # For the first step, we will only load the model weights.
-        states = {MODEL: self.states[MODEL]} if step == 0 else self.states
-        states_to_load = {
-            k: v for k, v in states.items() if k not in self.exclude_from_loading
-        }
+        if model_only:
+            sd = self.states[MODEL].state_dict()
+            for k in excluded_parameters_for_model_only:
+                sd.pop(k, None)
+            return sd
+
         for exclude_key in self.exclude_from_loading:
-            if exclude_key not in states:
+            if exclude_key not in self.states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+
+        states_to_load = {
+            k: v for k, v in self.states.items() if k not in self.exclude_from_loading
+        }
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
+
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:

@@ -551,18 +586,16 @@ def _save_last_step(self, curr_step: int) -> None:
         # dtype conversion when we are checkpoint model weights only and the
        # current dtype is not the same as the export dtype at the end of the training.
 
-        if self.model_weights_only:
+        if self.last_save_model_weights_only:
             # We update self.states to keep the model only.
             # After this update, self.states = {
             #      'tok_embeddings.weight':...,
             #      'layers.0.attention.wq.weight': ...
             # }.
             self.states = self.states[MODEL].state_dict()
 
-            # For now, we will manually pop the freqs_cis buffer, as we made this permanent
-            # temporarily and we don't want to include it in the exported state_dict.
-            # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-            self.states.pop("freqs_cis", None)
+            for k in excluded_parameters_for_model_only:
+                self.states.pop(k, None)
 
             if self.export_dtype != torch.float32:
                 self.states = {

torchtitan/config_manager.py

Lines changed: 31 additions & 5 deletions

@@ -373,21 +373,47 @@ class Checkpoint:
     When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
     """
 
+    initial_load_path: str | None = None
+    """
+    This option specifies the path to the initial checkpoint to load, which is
+    particularly useful for resuming training from a previous run with a
+    different output path or when loading a checkpoint from a pre-trained model.
+    If the checkpoint folder for the current run is not empty,
+    located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.
+    This feature allows users to load an initial checkpoint from a different folder and
+    continue training, saving new checkpoints to the specified folder without affecting
+    the existing ones.
+
+    Note that the path should contain the full path to the checkpoint folder,
+    including the step number, if any; for example,
+    "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
+    """
+
+    initial_load_model_weights_only: bool = True
+    """
+    This option specifies if only the model weights should be loaded during the initial
+    checkpoint load. The option is only used when `initial_load_path` is specified.
+    If False, the checkpoint at `initial_load_path` is treated as a standard training
+    checkpoint, including optimizer and training states.
+    The default setting for this option is True. Note that you will have to use
+    `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
+    """
+
     interval: int = 500
     """Checkpointing interval in steps."""
 
-    model_weights_only: bool = False
+    last_save_model_weights_only: bool = False
     """
-    When model_weights_only=True, only model weights will be saved at the end of training.
-    With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
-    When model_weights_only=False, the full checkpoint will be saved.
+    When last_save_model_weights_only=True, only model weights will be saved at the end of training,
+    the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
+    after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
     A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
     The default value is false.
     """
 
     export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
     """
-    Converts to the specified precision when training completes and model_weights_only=true.
+    Converts to the specified precision when training completes and last_save_model_weights_only=true.
     """
 
     create_seed_checkpoint: bool = False
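
To make the two new options above concrete, here is a hedged config sketch for resuming with the full training state (optimizer and train state included) from another run's checkpoint folder; the paths are made up for illustration:

```
[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval = 500
# Point at the other run's checkpoint folder, including the step directory.
initial_load_path = "/runs/llama3_8b_phase1/checkpoint/step_10000"
# Load it as a full training checkpoint rather than model weights only.
initial_load_model_weights_only = false
```

Because the option defaults to True, the equivalent command-line override is spelled `--checkpoint.no_initial_load_model_weights_only`, as the docstring above notes.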

torchtitan/experiments/flux/tests/integration_tests.py

Lines changed: 2 additions & 2 deletions

@@ -58,11 +58,11 @@ def build_test_list():
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    "--checkpoint.model_weights_only",
+                    "--checkpoint.last_save_model_weights_only",
                 ],
             ],
             "Checkpoint Integration Test - Save Model Weights Only fp32",
-            "model_weights_only_fp32",
+            "last_save_model_weights_only_fp32",
         ),
         # Parallelism tests.
         OverrideDefinitions(

torchtitan/experiments/flux/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion

@@ -68,6 +68,6 @@ mode = "full"
 enable_checkpoint = false
 folder = "checkpoint"
 interval = 5
-model_weights_only = false
+last_save_model_weights_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

torchtitan/experiments/flux/train_configs/flux_dev_model.toml

Lines changed: 1 addition & 1 deletion

@@ -67,6 +67,6 @@ mode = "full"
 enable_checkpoint = false
 folder = "checkpoint"
 interval = 1_000
-model_weights_only = false
+last_save_model_weights_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

torchtitan/experiments/flux/train_configs/flux_schnell_model.toml

Lines changed: 1 addition & 1 deletion

@@ -67,6 +67,6 @@ mode = "full"
 enable_checkpoint = false
 folder = "checkpoint"
 interval = 1_000
-model_weights_only = false
+last_save_model_weights_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py

Lines changed: 1 addition & 2 deletions

@@ -322,7 +322,6 @@ def _reshard_send(
     def _reshard_receive(
         self, assignment: _Assignment, state_dict: dict[str, torch.Tensor]
     ) -> dict[str, torch.Tensor]:
-
         flatten_tensor = torch.empty(
             sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)),
             dtype=assignment.dtypes[0],

@@ -535,7 +534,7 @@ def state_dict(self) -> dict[str, torch.Tensor]:
         # oh, this is pretty bad, when can we get rid of the freqs_cis issue?
         state_dict["freqs_cis"] = None
         trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
-        trainer.checkpointer.model_weights_only = True
+        trainer.checkpointer.last_save_model_weights_only = True
         trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
         trainer.checkpointer.save(curr_step=0, force=True)
         time.sleep(2)
