@@ -16,6 +16,7 @@
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchtitan.components.checkpoint import CheckpointManager, MODEL
+from torchtitan.config_manager import Checkpoint as CheckpointConfig


 class FakeOptimizersContainer:
@@ -81,7 +82,7 @@ def fake_async_save(*args, **kwargs):
 class DummyJobConfig:
     def __init__(self, job):
         self.job = job
-        self.checkpoint = SimpleNamespace(
+        self.checkpoint = CheckpointConfig(
             enable_checkpoint=True,
             async_mode="disabled",
             folder="",
@@ -112,7 +113,7 @@ def setUp(self):
         self.data_loader = FakeDataLoader()
         self.ft_manager = DummyFTManager()

-        ckpt_cfg = SimpleNamespace(
+        ckpt_cfg = CheckpointConfig(
             enable_checkpoint=True,
             async_mode="DISABLED",
             folder="",
@@ -325,17 +326,17 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
             ft_manager=self.ft_manager,
         )
         manager.save(curr_step=1)
-        self.assertEqual(mock_save.call_count, 1)
+        self.assertEqual(mock_save.call_count, 0)
         manager.save(curr_step=2)
-        self.assertEqual(mock_save.call_count, 1)
+        self.assertEqual(mock_save.call_count, 0)
         manager.save(curr_step=2, force=True)
-        self.assertEqual(mock_save.call_count, 2)
+        self.assertEqual(mock_save.call_count, 1)
         manager.save(curr_step=3)
-        self.assertEqual(mock_save.call_count, 3)
+        self.assertEqual(mock_save.call_count, 2)
         manager.save(curr_step=4)
-        self.assertEqual(mock_save.call_count, 3)
+        self.assertEqual(mock_save.call_count, 2)
         manager.save(curr_step=4, force=True)
-        self.assertEqual(mock_save.call_count, 4)
+        self.assertEqual(mock_save.call_count, 3)
         manager.close()

     @mock.patch("torch.distributed.get_rank", return_value=0)
@@ -471,6 +472,68 @@ def test_ft_async_save_calls_async_wait(
         self.assertIsNotNone(manager.async_future)
         manager.async_future.result.assert_not_called()

+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
+        """
+        Test that enable_first_step_checkpoint triggers checkpoint save at step 1.
+        """
+        mock_save.side_effect = self.fake_save
+
+        # Test with enable_first_step_checkpoint=False (default case)
+        cfg = self.job_config.checkpoint
+        cfg.interval = 10  # Set interval to 10 so step 1 wouldn't normally trigger save
+        cfg.keep_latest_k = 0  # Disable purging to avoid confusion
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        # Step 1 should not trigger save when enable_first_step_checkpoint=False
+        # and not at interval
+        manager.save(curr_step=1)
+        self.assertEqual(mock_save.call_count, 0)
+
+        # Step 10 should trigger save due to interval
+        manager.save(curr_step=10)
+        self.assertEqual(mock_save.call_count, 1)
+
+        manager.close()
+
+        # Test with enable_first_step_checkpoint=True
+        mock_save.reset_mock()
+        cfg.enable_first_step_checkpoint = True
+
+        manager2 = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        # Step 1 should trigger save due to enable_first_step_checkpoint=True
+        manager2.save(curr_step=1)
+        self.assertEqual(mock_save.call_count, 1)
+
+        # Step 2 should not trigger save (not at interval and not forced)
+        manager2.save(curr_step=2)
+        self.assertEqual(mock_save.call_count, 1)
+
+        # Step 10 should trigger save due to interval
+        manager2.save(curr_step=10)
+        self.assertEqual(mock_save.call_count, 2)
+
+        manager2.close()
+

 if __name__ == "__main__":
     unittest.main()