
Commit f4048f8

Make checkpoint fail_fast feature optional (#1310)
While the fail_fast checkpointing feature is useful, it can also waste time and storage when the cluster has already been verified with TorchTitan. This PR makes the fail_fast feature optional, defaulting to False.
1 parent f7084fc commit f4048f8
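
In effect, the previously unconditional step-1 save becomes opt-in. A minimal sketch of the new gating condition, paraphrased from the `_should_save` change in the diff below (the helper name here is hypothetical, not part of the commit):

```python
def save_at_first_step(curr_step: int, enable_first_step_checkpoint: bool) -> bool:
    """Hypothetical paraphrase of the changed step-1 gate in _should_save."""
    # Before this commit the step-1 save was unconditional (curr_step == 1);
    # now it also requires the new flag, which defaults to False.
    return curr_step == 1 and enable_first_step_checkpoint
```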

File tree: 3 files changed, +88 -15 lines

- tests/unit_tests/test_checkpoint.py
- torchtitan/components/checkpoint.py
- torchtitan/config_manager.py

tests/unit_tests/test_checkpoint.py
71 additions, 8 deletions

```diff
@@ -16,6 +16,7 @@
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchtitan.components.checkpoint import CheckpointManager, MODEL
+from torchtitan.config_manager import Checkpoint as CheckpointConfig
 
 
 class FakeOptimizersContainer:
@@ -81,7 +82,7 @@ def fake_async_save(*args, **kwargs):
 class DummyJobConfig:
     def __init__(self, job):
         self.job = job
-        self.checkpoint = SimpleNamespace(
+        self.checkpoint = CheckpointConfig(
             enable_checkpoint=True,
             async_mode="disabled",
             folder="",
@@ -112,7 +113,7 @@ def setUp(self):
         self.data_loader = FakeDataLoader()
         self.ft_manager = DummyFTManager()
 
-        ckpt_cfg = SimpleNamespace(
+        ckpt_cfg = CheckpointConfig(
             enable_checkpoint=True,
             async_mode="DISABLED",
             folder="",
@@ -325,17 +326,17 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
             ft_manager=self.ft_manager,
         )
         manager.save(curr_step=1)
-        self.assertEqual(mock_save.call_count, 1)
+        self.assertEqual(mock_save.call_count, 0)
         manager.save(curr_step=2)
-        self.assertEqual(mock_save.call_count, 1)
+        self.assertEqual(mock_save.call_count, 0)
         manager.save(curr_step=2, force=True)
-        self.assertEqual(mock_save.call_count, 2)
+        self.assertEqual(mock_save.call_count, 1)
         manager.save(curr_step=3)
-        self.assertEqual(mock_save.call_count, 3)
+        self.assertEqual(mock_save.call_count, 2)
         manager.save(curr_step=4)
-        self.assertEqual(mock_save.call_count, 3)
+        self.assertEqual(mock_save.call_count, 2)
         manager.save(curr_step=4, force=True)
-        self.assertEqual(mock_save.call_count, 4)
+        self.assertEqual(mock_save.call_count, 3)
         manager.close()
 
     @mock.patch("torch.distributed.get_rank", return_value=0)
@@ -471,6 +472,68 @@ def test_ft_async_save_calls_async_wait(
         self.assertIsNotNone(manager.async_future)
         manager.async_future.result.assert_not_called()
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
+        """
+        Test that enable_first_step_checkpoint triggers checkpoint save at step 1.
+        """
+        mock_save.side_effect = self.fake_save
+
+        # Test with enable_first_step_checkpoint=False (default case)
+        cfg = self.job_config.checkpoint
+        cfg.interval = 10  # Set interval to 10 so step 1 wouldn't normally trigger save
+        cfg.keep_latest_k = 0  # Disable purging to avoid confusion
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        # Step 1 should not trigger save when enable_first_step_checkpoint=False
+        # and not at interval
+        manager.save(curr_step=1)
+        self.assertEqual(mock_save.call_count, 0)
+
+        # Step 10 should trigger save due to interval
+        manager.save(curr_step=10)
+        self.assertEqual(mock_save.call_count, 1)
+
+        manager.close()
+
+        # Test with enable_first_step_checkpoint=True
+        mock_save.reset_mock()
+        cfg.enable_first_step_checkpoint = True
+
+        manager2 = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        # Step 1 should trigger save due to enable_first_step_checkpoint=True
+        manager2.save(curr_step=1)
+        self.assertEqual(mock_save.call_count, 1)
+
+        # Step 2 should not trigger save (not at interval and not forced)
+        manager2.save(curr_step=2)
+        self.assertEqual(mock_save.call_count, 1)
+
+        # Step 10 should trigger save due to interval
+        manager2.save(curr_step=10)
+        self.assertEqual(mock_save.call_count, 2)
+
+        manager2.close()
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchtitan/components/checkpoint.py
9 additions, 7 deletions

```diff
@@ -273,11 +273,19 @@ def load_state_dict(state_dict):
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
+
+        # Checkpoint policy related fields.
         self.initial_load_path = ckpt_config.initial_load_path
         self.initial_load_model_weights_only = (
             ckpt_config.initial_load_model_weights_only
         )
+        self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
+        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
+        self.exclude_from_loading = ckpt_config.exclude_from_loading
         self.interval = ckpt_config.interval
+        self.enable_first_step_checkpoint = ckpt_config.enable_first_step_checkpoint
+
+        # Async checkpoint related fields.
         async_mode = ckpt_config.async_mode.lower()
         if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
@@ -297,10 +305,6 @@ def load_state_dict(state_dict):
         else:
             self.purge_thread = None
 
-        self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
-        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
-        self.exclude_from_loading = ckpt_config.exclude_from_loading
-
         self.mp = None
         self.async_future = None
         if async_mode == AsyncMode.DISABLED:
@@ -616,9 +620,7 @@ def _should_save(self, curr_step: int, force: bool = False) -> bool:
         if not self.enable_checkpoint:
             return False
 
-        # Force saving a checkpoint at step 1 to fail fast if checkpointer is not
-        # compatible with the cluster.
-        if curr_step == 1:
+        if curr_step == 1 and self.enable_first_step_checkpoint:
             return True
 
         if force:
```
torchtitan/config_manager.py
8 additions, 0 deletions

```diff
@@ -459,6 +459,14 @@ class Checkpoint:
     This will load the model only, excluding the specified keys.
     """
 
+    enable_first_step_checkpoint: bool = False
+    """
+    Enable the checkpoint save at first step. This will save a checkpoint immediately
+    after the first step to ensure checkpointing functions correctly. This is useful
+    when running on a new cluster or storage to verify checkpointing without waiting
+    for many steps or checkpointing too frequently. The default value is False.
+    """
+
 
 @dataclass
 class ActivationCheckpoint:
```
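
Because `Checkpoint` is a dataclass whose fields carry defaults, flipping the new flag is a one-field change. A minimal usage sketch, assuming the dataclass is default-constructible (which holds if every field has a default, as the ones shown here do):

```python
from torchtitan.config_manager import Checkpoint

# The new flag defaults to False, so existing configs keep their behavior.
assert Checkpoint().enable_first_step_checkpoint is False

# Opt back into the pre-#1310 fail-fast behavior explicitly.
ckpt_cfg = Checkpoint(enable_checkpoint=True, enable_first_step_checkpoint=True)
```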
