15
15
import torch
16
16
import torch .nn as nn
17
17
from torch .utils .data import DataLoader
18
- from torchtitan .components .checkpoint import CheckpointManager , MODEL
18
+ from torchtitan .components .checkpoint import CheckpointManager
19
19
from torchtitan .config_manager import Checkpoint as CheckpointConfig
20
20
21
21
@@ -105,11 +105,11 @@ def setUp(self):
105
105
106
106
self .model_part = nn .Linear (2 , 2 )
107
107
self .model_parts = [self .model_part ]
108
+ self .states = {"trainer" : torch .tensor ([1.2347 ])}
108
109
# TODO: Use a real OptimizerContainer here so that we can actually verify
109
110
# some optimizer.state_dict() behavior (e.g., the key being the parameter name.)
110
111
self .optimizers = FakeOptimizersContainer ()
111
112
self .lr_schedulers = FakeLRSchedulersContainer ()
112
- self .states = {}
113
113
self .data_loader = FakeDataLoader ()
114
114
self .ft_manager = DummyFTManager ()
115
115
@@ -161,7 +161,7 @@ def fake_load(self, states: dict, checkpoint_id=None):
161
161
if key in states and hasattr (states [key ], "load_state_dict" ):
162
162
states [key ].load_state_dict (val )
163
163
elif key in states and isinstance (states [key ], torch .Tensor ):
164
- states [key ] = val
164
+ states [key ]. copy_ ( val )
165
165
166
166
@mock .patch ("torch.distributed.get_rank" , return_value = 0 )
167
167
@mock .patch ("torchtitan.components.checkpoint.dcp.save" )
@@ -354,7 +354,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
354
354
model_parts = self .model_parts ,
355
355
optimizers = self .optimizers ,
356
356
lr_schedulers = self .lr_schedulers ,
357
- states = { MODEL : self .model_part } ,
357
+ states = self .states ,
358
358
job_config = self .job_config ,
359
359
ft_manager = self .ft_manager ,
360
360
)
@@ -373,7 +373,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
373
373
model_parts = self .model_parts ,
374
374
optimizers = self .optimizers ,
375
375
lr_schedulers = self .lr_schedulers ,
376
- states = { MODEL : self .model_part } ,
376
+ states = self .states ,
377
377
job_config = self .job_config ,
378
378
ft_manager = self .ft_manager ,
379
379
)
@@ -449,13 +449,12 @@ def test_ft_async_save_calls_async_wait(
449
449
job_config .checkpoint .async_mode = "disabled"
450
450
ft_manager = mock .Mock ()
451
451
ft_manager .enabled = True
452
- states = {"trainer" : torch .tensor ([0 ])}
453
452
manager = CheckpointManager (
454
453
dataloader = self .data_loader ,
455
454
model_parts = self .model_parts ,
456
455
optimizers = self .optimizers ,
457
456
lr_schedulers = self .lr_schedulers ,
458
- states = states ,
457
+ states = self . states ,
459
458
job_config = job_config ,
460
459
ft_manager = ft_manager ,
461
460
)
@@ -569,14 +568,13 @@ def __init__(self):
569
568
self .assertEqual (mock_save .call_count , 1 )
570
569
checkpoint_path = os .path .join (self .test_folder , "step-1" , "state_dict.pt" )
571
570
saved_data = torch .load (checkpoint_path , weights_only = False )
572
- model_state_dict = saved_data [MODEL ]
573
571
574
572
# Verify that freqs_cis is NOT in the saved state dict
575
- self .assertNotIn ("freqs_cis" , model_state_dict )
573
+ self .assertNotIn ("freqs_cis" , saved_data )
576
574
# Verify that other parameters ARE in the saved state dict
577
- self .assertIn ("weight" , model_state_dict )
578
- self .assertIn ("bias" , model_state_dict )
579
- self .assertIn ("other_param" , model_state_dict )
575
+ self .assertIn ("weight" , saved_data )
576
+ self .assertIn ("bias" , saved_data )
577
+ self .assertIn ("other_param" , saved_data )
580
578
581
579
manager .close ()
582
580
0 commit comments