
Commit 9c297ec (parent: 80c9d9c)

Cleanup Apex vs native AMP scaler state save/load. Cleanup CheckpointSaver a bit.

4 files changed (+111 / -75 lines)


timm/models/helpers.py

Lines changed: 20 additions & 9 deletions

@@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
     model.load_state_dict(state_dict, strict=strict)
 
 
-def resume_checkpoint(model, checkpoint_path):
-    other_state = {}
+def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+            if log_info:
+                _logger.info('Restoring model state from checkpoint...')
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
                 name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
-            if 'optimizer' in checkpoint:
-                other_state['optimizer'] = checkpoint['optimizer']
-            if 'amp' in checkpoint:
-                other_state['amp'] = checkpoint['amp']
+
+            if optimizer is not None and 'optimizer' in checkpoint:
+                if log_info:
+                    _logger.info('Restoring optimizer state from checkpoint...')
+                optimizer.load_state_dict(checkpoint['optimizer'])
+
+            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
+                if log_info:
+                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
+                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
+
             if 'epoch' in checkpoint:
                 resume_epoch = checkpoint['epoch']
                 if 'version' in checkpoint and checkpoint['version'] > 1:
                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
-            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+
+            if log_info:
+                _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
-            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
-        return other_state, resume_epoch
+            if log_info:
+                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
+        return resume_epoch
     else:
         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
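
For orientation, a minimal sketch of how the refactored resume_checkpoint is meant to be called (this mirrors the train.py hunk further down); the surrounding model, optimizer, loss_scaler, and args objects are assumed to come from the caller's own setup:

    # Sketch only: resuming with the new signature; `model`, `optimizer`, `loss_scaler`,
    # and `args` are assumed to exist in the calling training script.
    from timm.models import resume_checkpoint

    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,  # pass None to skip optimizer state
            loss_scaler=loss_scaler,  # ApexScaler / NativeScaler instance, or None
            log_info=args.local_rank == 0)  # log only from the primary process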

timm/utils.py

Lines changed: 79 additions & 17 deletions

@@ -37,20 +37,67 @@ def unwrap_model(model):
     return model.module if hasattr(model, 'module') else model
 
 
-def get_state_dict(model):
-    return unwrap_model(model).state_dict()
+def get_state_dict(model, unwrap_fn=unwrap_model):
+    return unwrap_fn(model).state_dict()
+
+
+class ApexScaler:
+    state_dict_key = "amp"
+
+    def __call__(self, loss, optimizer):
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        optimizer.step()
+
+    def state_dict(self):
+        if 'state_dict' in amp.__dict__:
+            return amp.state_dict()
+
+    def load_state_dict(self, state_dict):
+        if 'load_state_dict' in amp.__dict__:
+            amp.load_state_dict(state_dict)
+
+
+class NativeScaler:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer):
+        self._scaler.scale(loss).backward()
+        self._scaler.step(optimizer)
+        self._scaler.update()
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
 
 
 class CheckpointSaver:
     def __init__(
             self,
+            model,
+            optimizer,
+            args=None,
+            model_ema=None,
+            amp_scaler=None,
             checkpoint_prefix='checkpoint',
             recovery_prefix='recovery',
             checkpoint_dir='',
             recovery_dir='',
             decreasing=False,
             max_history=10,
-            save_amp=False):
+            unwrap_fn=unwrap_model):
+
+        # objects to save state_dicts of
+        self.model = model
+        self.optimizer = optimizer
+        self.args = args
+        self.model_ema = model_ema
+        self.amp_scaler = amp_scaler
 
         # state
         self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness

@@ -68,14 +115,14 @@ def __init__(
         self.decreasing = decreasing  # a lower metric is better if True
         self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
         self.max_history = max_history
-        self.save_apex_amp = save_amp  # save APEX amp state
+        self.unwrap_fn = unwrap_fn
         assert self.max_history >= 1
 
-    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def save_checkpoint(self, epoch, metric=None):
        assert epoch >= 0
         tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
-        self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
+        self._save(tmp_save_path, epoch, metric)
         if os.path.exists(last_save_path):
             os.unlink(last_save_path)  # required for Windows support.
         os.rename(tmp_save_path, last_save_path)

@@ -107,19 +154,21 @@ def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=
 
         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
 
-    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def _save(self, save_path, epoch, metric=None):
         save_state = {
             'epoch': epoch,
-            'arch': args.model,
-            'state_dict': get_state_dict(model),
-            'optimizer': optimizer.state_dict(),
-            'args': args,
+            'arch': type(self.model).__name__.lower(),
+            'state_dict': get_state_dict(self.model, self.unwrap_fn),
+            'optimizer': self.optimizer.state_dict(),
             'version': 2,  # version < 2 increments epoch before save
         }
-        if self.save_apex_amp and 'state_dict' in amp.__dict__:
-            save_state['amp'] = amp.state_dict()
-        if model_ema is not None:
-            save_state['state_dict_ema'] = get_state_dict(model_ema)
+        if self.args is not None:
+            save_state['arch'] = self.args.model
+            save_state['args'] = self.args
+        if self.amp_scaler is not None:
+            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
+        if self.model_ema is not None:
+            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
         if metric is not None:
             save_state['metric'] = metric
         torch.save(save_state, save_path)

@@ -138,11 +187,11 @@ def _cleanup_checkpoints(self, trim=0):
             _logger.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]
 
-    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+    def save_recovery(self, epoch, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        self._save(save_path, model, optimizer, args, epoch, model_ema)
+        self._save(save_path, epoch)
         if os.path.exists(self.last_recovery_file):
             try:
                 _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))

@@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''):
     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
     parser.set_defaults(**{dest_name: default})
+
+
+def set_jit_legacy():
+    """ Set JIT executor to legacy w/ support for op fusion
+    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
+    in the JIT executor. These APIs are not supported and could change.
+    """
+    #
+    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_override_can_fuse_on_gpu(True)
+    #torch._C._jit_set_texpr_fuser_enabled(True)
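
The ApexScaler and NativeScaler classes above give the two AMP backends one interface: call the object to scale the loss, run backward(), and step the optimizer, and use state_dict()/load_state_dict() plus state_dict_key for checkpointing. A rough usage sketch with native AMP follows, assuming PyTorch >= 1.6 and placeholder model/optimizer/loader/loss_fn objects:

    # Rough sketch, not part of the diff; `model`, `optimizer`, `loader`, `loss_fn` are placeholders.
    import torch
    from timm.utils import NativeScaler

    loss_scaler = NativeScaler()
    for input, target in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(input), target)
        loss_scaler(loss, optimizer)  # scales loss, calls backward(), scaler.step(), scaler.update()

    # Either scaler checkpoints under its own key ('amp' for Apex, 'amp_scaler' for native),
    # which is what CheckpointSaver._save() and resume_checkpoint() key off.
    state = {loss_scaler.state_dict_key: loss_scaler.state_dict()}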

train.py

Lines changed: 11 additions & 35 deletions

@@ -20,7 +20,6 @@
 from datetime import datetime
 from contextlib import suppress
 
-import torch
 import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

@@ -31,6 +30,7 @@
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
 from timm.scheduler import create_scheduler
+from timm.utils import ApexScaler, NativeScaler
 
 try:
     from apex import amp

@@ -264,23 +264,6 @@ def _parse_args():
     return args, args_text
 
 
-class ApexScaler:
-    def __call__(self, loss, optimizer):
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward()
-        optimizer.step()
-
-
-class NativeScaler:
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
-
-    def __call__(self, loss, optimizer):
-        self._scaler.scale(loss).backward()
-        self._scaler.step(optimizer)
-        self._scaler.update()
-
-
 def main():
     setup_default_logging()
     args, args_text = _parse_args()

@@ -389,20 +372,13 @@ def main():
         _logger.info('AMP not enabled. Training in float32.')
 
     # optionally resume from a checkpoint
-    resume_state = {}
     resume_epoch = None
     if args.resume:
-        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
-        if resume_state and not args.no_resume_opt:
-            if 'optimizer' in resume_state:
-                if args.local_rank == 0:
-                    _logger.info('Restoring optimizer state from checkpoint')
-                optimizer.load_state_dict(resume_state['optimizer'])
-            if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
-                if args.local_rank == 0:
-                    _logger.info('Restoring NVIDIA AMP state from checkpoint')
-                amp.load_state_dict(resume_state['amp'])
-        del resume_state
+        resume_epoch = resume_checkpoint(
+            model, args.resume,
+            optimizer=None if args.no_resume_opt else optimizer,
+            loss_scaler=None if args.no_resume_opt else loss_scaler,
+            log_info=args.local_rank == 0)
 
     model_ema = None
     if args.model_ema:

@@ -555,7 +531,9 @@ def main():
         ])
         output_dir = get_outdir(output_base, 'train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex')
+        saver = CheckpointSaver(
+            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
+            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 

@@ -594,8 +572,7 @@ def main():
             if saver is not None:
                 # save proper checkpoint with eval metric
                 save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(
-                    model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
 
     except KeyboardInterrupt:
         pass

@@ -688,8 +665,7 @@ def train_epoch(
 
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-
-            saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
+            saver.save_recovery(epoch, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

validate.py

Lines changed: 1 addition & 14 deletions

@@ -21,7 +21,7 @@
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
 
 has_apex = False
 try:

@@ -102,19 +102,6 @@
                     help='Valid label indices txt file for validation of partial label space')
 
 
-def set_jit_legacy():
-    """ Set JIT executor to legacy w/ support for op fusion
-    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
-    in the JIT exectutor. These API are not supported so could change.
-    """
-    #
-    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
-    torch._C._jit_set_profiling_executor(False)
-    torch._C._jit_set_profiling_mode(False)
-    torch._C._jit_override_can_fuse_on_gpu(True)
-    #torch._C._jit_set_texpr_fuser_enabled(True)
-
-
 def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
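
With set_jit_legacy() now living in timm.utils and imported above, validate.py can call it before scripting a model. A sketch of the typical call site follows; the args.legacy_jit and args.torchscript flags are assumptions used for illustration, not part of this diff:

    # Illustrative only; the guard flags are assumed, not quoted from the commit.
    if args.legacy_jit:
        set_jit_legacy()  # fall back to the legacy JIT executor so op fusion still applies
    if args.torchscript:
        model = torch.jit.script(model)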
