
Commit 532e3b4

Reorg of utils into separate modules
1 parent 9ce42d5 commit 532e3b4

File tree

13 files changed: +466 -401 lines changed


timm/utils.py

Lines changed: 0 additions & 400 deletions
This file was deleted.

timm/utils/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
from .checkpoint_saver import CheckpointSaver
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict
from .model_ema import ModelEma
from .summary import update_summary, get_outdir

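For context, these re-exports keep the flat `timm.utils` namespace intact, so existing call sites are unaffected by the reorg. A minimal sketch (the specific symbols chosen below are only for illustration):

# old-style flat imports keep working because the package re-exports the moved symbols
from timm.utils import CheckpointSaver, ModelEma, AverageMeter, accuracy, setup_default_logging

# the new per-module paths also work and resolve to the same objects
from timm.utils.metrics import accuracy as accuracy_direct
assert accuracy is accuracy_direct
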
timm/utils/checkpoint_saver.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
""" Checkpoint Saver

Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.

Hacked together by / Copyright 2020 Ross Wightman
"""

import glob
import operator
import os
import logging

import torch

from .model import unwrap_model, get_state_dict


_logger = logging.getLogger(__name__)


class CheckpointSaver:
    def __init__(
            self,
            model,
            optimizer,
            args=None,
            model_ema=None,
            amp_scaler=None,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10,
            unwrap_fn=unwrap_model):

        # objects to save state_dicts of
        self.model = model
        self.optimizer = optimizer
        self.args = args
        self.model_ema = model_ema
        self.amp_scaler = amp_scaler

        # state
        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        self.unwrap_fn = unwrap_fn
        assert self.max_history >= 1

    def save_checkpoint(self, epoch, metric=None):
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        self._save(tmp_save_path, epoch, metric)
        if os.path.exists(last_save_path):
            os.unlink(last_save_path)  # required for Windows support.
        os.rename(tmp_save_path, last_save_path)
        worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
        if (len(self.checkpoint_files) < self.max_history
                or metric is None or self.cmp(metric, worst_file[1])):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            os.link(last_save_path, save_path)
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = sorted(
                self.checkpoint_files, key=lambda x: x[1],
                reverse=not self.decreasing)  # sort in descending order if a lower metric is not better

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += ' {}\n'.format(c)
            _logger.info(checkpoints_str)

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                if os.path.exists(best_save_path):
                    os.unlink(best_save_path)
                os.link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _save(self, save_path, epoch, metric=None):
        save_state = {
            'epoch': epoch,
            'arch': type(self.model).__name__.lower(),
            'state_dict': get_state_dict(self.model, self.unwrap_fn),
            'optimizer': self.optimizer.state_dict(),
            'version': 2,  # version < 2 increments epoch before save
        }
        if self.args is not None:
            save_state['arch'] = self.args.model
            save_state['args'] = self.args
        if self.amp_scaler is not None:
            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
        if self.model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d[0])
            except Exception as e:
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, epoch, batch_idx=0):
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, epoch)
        if os.path.exists(self.last_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
                os.remove(self.last_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        if len(files):
            return files[0]
        else:
            return ''

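For orientation, a minimal sketch of how the saver might be driven from a training loop. The toy model, optimizer, output directory, and stand-in validation metric below are illustrative, not part of this commit:

import os
import torch
from timm.utils import CheckpointSaver

# toy model/optimizer purely to exercise the saver
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

os.makedirs('./output', exist_ok=True)
saver = CheckpointSaver(
    model, optimizer, checkpoint_dir='./output', recovery_dir='./output',
    decreasing=True, max_history=3)  # decreasing=True: a lower metric (e.g. loss) is better

for epoch in range(3):
    val_loss = 1.0 / (epoch + 1)  # stand-in validation metric
    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=val_loss)

Each call writes 'last.pth.tar', hard-links it to 'checkpoint-{epoch}.pth.tar' when the metric makes the top-n, and refreshes 'model_best.pth.tar' when a new best is seen.
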
timm/utils/cuda.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
""" CUDA / AMP utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

try:
    from apex import amp
    has_apex = True
except ImportError:
    amp = None
    has_apex = False


class ApexScaler:
    state_dict_key = "amp"

    def __call__(self, loss, optimizer):
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

    def state_dict(self):
        if 'state_dict' in amp.__dict__:
            return amp.state_dict()

    def load_state_dict(self, state_dict):
        if 'load_state_dict' in amp.__dict__:
            amp.load_state_dict(state_dict)


class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer):
        self._scaler.scale(loss).backward()
        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

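A minimal sketch of the NativeScaler call pattern inside a mixed-precision training step. The model, data, and loss are stand-ins, and a CUDA device is assumed:

import torch
import torch.nn.functional as F
from timm.utils import NativeScaler

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

x = torch.randn(4, 8, device='cuda')
target = torch.randint(0, 2, (4,), device='cuda')

with torch.cuda.amp.autocast():
    loss = F.cross_entropy(model(x), target)

optimizer.zero_grad()
loss_scaler(loss, optimizer)  # scales the loss, backprops, steps the optimizer, updates the scaler

Both scaler classes expose the same __call__/state_dict interface, so the training script can swap Apex and native AMP without changing the step logic.
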
timm/utils/distributed.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
""" Distributed training/validation utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import distributed as dist

from .model import unwrap_model


def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


def distribute_bn(model, world_size, reduce=False):
    # ensure every node has the same running bn stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
                torch.distributed.broadcast(bn_buf, 0)

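A minimal sketch of how these helpers are typically used, assuming a process group has already been initialized (e.g. via torch.distributed.launch); the metric and BatchNorm module are stand-ins:

import torch
import torch.distributed as dist
from timm.utils import distribute_bn, reduce_tensor

# assumes dist.init_process_group(...) has already been called on every rank
world_size = dist.get_world_size()

# average a per-rank metric (e.g. validation loss) across all ranks
local_loss = torch.tensor(0.5)
global_loss = reduce_tensor(local_loss, world_size)

# after an epoch, either average BN running stats across ranks or broadcast rank 0's
model = torch.nn.BatchNorm1d(16)
distribute_bn(model, world_size, reduce=True)
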
timm/utils/jit.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
""" JIT scripting/tracing utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch


def set_jit_legacy():
    """ Set JIT executor to legacy w/ support for op fusion
    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
    in the JIT executor. These APIs are not supported and could change.
    """
    #
    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_override_can_fuse_on_gpu(True)
    #torch._C._jit_set_texpr_fuser_enabled(True)

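A minimal sketch of where set_jit_legacy would be called, i.e. before scripting a model; the model below is a stand-in, and the helper targets the PyTorch 1.5/1.6 executor changes noted in the docstring:

import torch
from timm.utils import set_jit_legacy

set_jit_legacy()  # switch to the legacy fusing executor before any scripting happens

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
scripted = torch.jit.script(model.eval())
out = scripted(torch.randn(2, 8))
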
timm/utils/log.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
""" Logging helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import logging.handlers


class FormatterNoInfo(logging.Formatter):
    def __init__(self, fmt='%(levelname)s: %(message)s'):
        logging.Formatter.__init__(self, fmt)

    def format(self, record):
        if record.levelno == logging.INFO:
            return str(record.getMessage())
        return logging.Formatter.format(self, record)


def setup_default_logging(default_level=logging.INFO, log_path=''):
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(FormatterNoInfo())
    logging.root.addHandler(console_handler)
    logging.root.setLevel(default_level)
    if log_path:
        file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
        file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
        file_handler.setFormatter(file_formatter)
        logging.root.addHandler(file_handler)

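A minimal sketch of the logging setup; the 'train.log' path and logger name are illustrative:

import logging
from timm.utils import setup_default_logging

setup_default_logging(log_path='train.log')  # console handler plus a 2 MB rotating file handler

_logger = logging.getLogger('train')
_logger.info('plain message, no level prefix on the console')   # FormatterNoInfo drops the INFO prefix
_logger.warning('warnings and errors keep their level prefix')
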
timm/utils/metrics.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
""" Eval metrics and related

Hacked together by / Copyright 2020 Ross Wightman
"""


class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk]

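A minimal sketch of the two metrics helpers together; the random logits and labels are stand-ins for a real model's output:

import torch
from timm.utils import AverageMeter, accuracy

# fake logits for an 8-sample, 5-class batch with matching labels
output = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))

top1, top5 = accuracy(output, target, topk=(1, 5))  # percentages in [0, 100]

top1_meter = AverageMeter()
top1_meter.update(top1.item(), n=output.size(0))  # running average weighted by batch size
print('top-1 avg: {:.2f}%'.format(top1_meter.avg))
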
timm/utils/misc.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
""" Misc utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import re


def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def add_bool_arg(parser, name, default=False, help=''):
    dest_name = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
    parser.set_defaults(**{dest_name: default})

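A minimal sketch of both helpers; the 'pin-mem' flag name and the file-like strings are illustrative:

import argparse
from timm.utils import add_bool_arg, natural_key

# natural sort: numeric runs compare as numbers, not strings
names = ['epoch10', 'epoch2', 'epoch1']
print(sorted(names, key=natural_key))  # ['epoch1', 'epoch2', 'epoch10']

# paired --pin-mem / --no-pin-mem flags sharing one destination with a default
parser = argparse.ArgumentParser()
add_bool_arg(parser, 'pin-mem', default=True, help='pin CPU memory in DataLoader')
print(parser.parse_args(['--no-pin-mem']).pin_mem)  # False
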
timm/utils/model.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
""" Model / state_dict utils

Hacked together by / Copyright 2020 Ross Wightman
"""
from .model_ema import ModelEma


def unwrap_model(model):
    if isinstance(model, ModelEma):
        return unwrap_model(model.ema)
    else:
        return model.module if hasattr(model, 'module') else model


def get_state_dict(model, unwrap_fn=unwrap_model):
    return unwrap_fn(model).state_dict()

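A minimal sketch of unwrapping a wrapped model before saving its weights; the DataParallel wrapper stands in for any wrapper exposing a .module attribute:

import torch
from timm.utils import get_state_dict, unwrap_model

model = torch.nn.Linear(8, 2)
wrapped = torch.nn.DataParallel(model)  # adds a .module attribute around the real model

assert unwrap_model(wrapped) is model
state = get_state_dict(wrapped)         # keys carry no 'module.' prefix
print(list(state.keys()))               # ['weight', 'bias']
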