Commit adae6b6

[RFC] Make last_save_model_weights_only default to True (#1336)
This is a BC-breaking change, but it should be the right way to save the last-step checkpoint.
1 parent d9cc6b4 · commit adae6b6
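Since the new default saves only model weights at the final step, jobs that rely on the last checkpoint to resume training must now opt out explicitly. A minimal sketch of both behaviors, assuming the `Checkpoint` config dataclass from `torchtitan/config_manager.py` shown in the diff below (the import path is assumed):

    from torchtitan.config_manager import Checkpoint  # assumed import path

    cfg = Checkpoint()
    assert cfg.last_save_model_weights_only is True  # new default after this commit

    # To keep the old behavior (a full final checkpoint with model, optimizer,
    # and train_state that can resume training), override the field explicitly:
    cfg = Checkpoint(last_save_model_weights_only=False)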

File tree

7 files changed, +22 −22 lines

tests/unit_tests/test_checkpoint.py

Lines changed: 8 additions & 8 deletions
@@ -329,13 +329,13 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
         self.assertEqual(mock_save.call_count, 0)
         manager.save(curr_step=2)
         self.assertEqual(mock_save.call_count, 0)
-        manager.save(curr_step=2, force=True)
+        manager.save(curr_step=2, last_step=True)
         self.assertEqual(mock_save.call_count, 1)
         manager.save(curr_step=3)
         self.assertEqual(mock_save.call_count, 2)
         manager.save(curr_step=4)
         self.assertEqual(mock_save.call_count, 2)
-        manager.save(curr_step=4, force=True)
+        manager.save(curr_step=4, last_step=True)
         self.assertEqual(mock_save.call_count, 3)
         manager.close()

@@ -358,7 +358,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
             job_config=self.job_config,
             ft_manager=self.ft_manager,
         )
-        manager1.save(curr_step=1, force=True)
+        manager1.save(curr_step=1, last_step=True)
         path1 = os.path.join(self.test_folder, "step-1")
         self.assertTrue(os.path.isdir(path1))
         # Phase 2: initial load from step-1

@@ -383,7 +383,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
         args1, kwargs1 = mock_load.call_args
         self.assertEqual(kwargs1.get("checkpoint_id"), path1)
         # Phase 3: save new step under default folder, then load that
-        manager2.save(curr_step=2, force=True)
+        manager2.save(curr_step=2, last_step=True)
         # Default folder is test_folder, so step-2 under that
         step2_dir = os.path.join(self.test_folder, "step-2")
         self.assertTrue(os.path.isdir(step2_dir))

@@ -419,12 +419,12 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
         )

         # First save schedules async
-        manager.save(curr_step=10, force=False)
+        manager.save(curr_step=10, last_step=False)
         future = manager.async_future
         future.result.assert_not_called()

         # Second save should wait
-        manager.save(curr_step=20, force=False)
+        manager.save(curr_step=20, last_step=False)
         future.result.assert_called_once()

         # New future created

@@ -462,12 +462,12 @@ def test_ft_async_save_calls_async_wait(

         # Initially no future
         self.assertIsNone(manager.async_future)
-        manager.save(curr_step=5, force=False)
+        manager.save(curr_step=5, last_step=False)
         self.assertIsNotNone(manager.async_future)

         manager.async_future.result.assert_not_called()
         prev_future = manager.async_future
-        manager.save(curr_step=6, force=False)
+        manager.save(curr_step=6, last_step=False)
         prev_future.result.assert_called_once()
         self.assertIsNotNone(manager.async_future)
         manager.async_future.result.assert_not_called()

torchtitan/components/checkpoint.py

Lines changed: 7 additions & 7 deletions
@@ -349,17 +349,17 @@ def close(self):
             self.purge_thread.join()

     @torch.no_grad()
-    def save(self, curr_step: int, force: bool = False) -> None:
+    def save(self, curr_step: int, last_step: bool = False) -> None:
         """Save the checkpoint for the current step.

-        This function will save the checkpoint for the current step. If ``force`` is
+        This function will save the checkpoint for the current step. If ``last_step`` is
         true, it will save the checkpoint even if the interval has not been reached.
         This only happens when train_state.step == job_config.training.steps, or
         for initial seed checkpoint.

         Args:
             curr_step (int): The current step.
-            force (bool, optional): Whether to force save the checkpoint. Defaults to False.
+            last_step (bool, optional): Whether this is the last step of training.

         Returns:
             None

@@ -368,7 +368,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
         if self.ft_manager:
             self._ft_save(curr_step)

-        if not self._should_save(curr_step, force):
+        if not self._should_save(curr_step, last_step):
             return

         begin = time.monotonic()

@@ -379,7 +379,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
         # This GC is called for async checkpoint as it is useless to do
         # GC right after async_save -- the CPU memory is not able to be
         # freed until _async_wait()
-        if force:
+        if last_step:
             self._save_last_step(curr_step)
         elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             GarbageCollection.collect("GC collection invoked by checkpointer.")

@@ -616,14 +616,14 @@ def _save_last_step(self, curr_step: int) -> None:

         save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))

-    def _should_save(self, curr_step: int, force: bool = False) -> bool:
+    def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
             return False

         if curr_step == 1 and self.enable_first_step_checkpoint:
             return True

-        if force:
+        if last_step:
             return True

         if curr_step % self.interval == 0:
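
For reference, a minimal sketch of the renamed call site. It assumes an already-constructed CheckpointManager bound to `checkpointer` and a `total_steps` value from the job config; `train_one_step` is a hypothetical helper. Passing `last_step=True` bypasses the interval check in `_should_save` and routes the save through `_save_last_step`:

    # Sketch only: `checkpointer`, `total_steps`, and `train_one_step`
    # are assumed to exist in the surrounding training script.
    for step in range(1, total_steps + 1):
        train_one_step()  # hypothetical single-step training helper
        checkpointer.save(curr_step=step, last_step=(step == total_steps))
    checkpointer.close()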

torchtitan/config_manager.py

Lines changed: 2 additions & 2 deletions
@@ -404,13 +404,13 @@ class Checkpoint:
     interval: int = 500
     """Checkpointing interval in steps."""

-    last_save_model_weights_only: bool = False
+    last_save_model_weights_only: bool = True
     """
     When last_save_model_weights_only=True, only model weights will be saved at the end of training,
     the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
     after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
     A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
-    The default value is false.
+    The default value is True.
     """

     export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
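
As the field's docstring notes, a weights-only final checkpoint can be loaded with `torch.load(..., weights_only=True)` once it has been converted out of DCP format. A minimal sketch, with a hypothetical converted file path and a stand-in model definition:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)  # stand-in for the real model definition
    # Hypothetical path to a checkpoint already converted from DCP to torch.save format:
    state_dict = torch.load("converted_checkpoint.pt", weights_only=True)
    model.load_state_dict(state_dict)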

torchtitan/experiments/flux/train.py

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ def train_step(
         assert (
             config.checkpoint.enable_checkpoint
         ), "Must enable checkpointing when creating a seed checkpoint."
-        trainer.checkpointer.save(curr_step=0, force=True)
+        trainer.checkpointer.save(curr_step=0, last_step=True)
         logger.info("Created seed checkpoint")
     else:
         trainer.train()

torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py

Lines changed: 1 addition & 1 deletion
@@ -536,7 +536,7 @@ def state_dict(self) -> dict[str, torch.Tensor]:
         trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
         trainer.checkpointer.last_save_model_weights_only = True
         trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
-        trainer.checkpointer.save(curr_step=0, force=True)
+        trainer.checkpointer.save(curr_step=0, last_step=True)
         time.sleep(2)
     finally:
         pass

torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py

Lines changed: 1 addition & 1 deletion
@@ -531,7 +531,7 @@ def state_dict(self) -> dict[str, torch.Tensor]:
         trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
         trainer.checkpointer.last_save_model_weights_only = True
         trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
-        trainer.checkpointer.save(curr_step=0, force=True)
+        trainer.checkpointer.save(curr_step=0, last_step=True)
         time.sleep(2)
     finally:
         pass

torchtitan/train.py

Lines changed: 2 additions & 2 deletions
@@ -494,7 +494,7 @@ def train(self):
                 logger.warning("Ran out of data; last step was canceled.")
                 break
             self.checkpointer.save(
-                self.step, force=(self.step == job_config.training.steps)
+                self.step, last_step=(self.step == job_config.training.steps)
             )

             # signal the profiler that the next profiling step has started

@@ -547,7 +547,7 @@ def close(self) -> None:
         assert (
             config.checkpoint.enable_checkpoint
         ), "Must enable checkpointing when creating a seed checkpoint."
-        trainer.checkpointer.save(curr_step=0, force=True)
+        trainer.checkpointer.save(curr_step=0, last_step=True)
         logger.info("Created seed checkpoint")
     else:
         trainer.train()
