Commit 25cf2c2

Update bulk_runner with improved filtering options for benchmarking / val runs
1 parent dfb8658 commit 25cf2c2

File tree

1 file changed: +34 −9 lines changed

bulk_runner.py

Lines changed: 34 additions & 9 deletions
@@ -21,7 +21,7 @@
 from typing import Callable, List, Tuple, Union
 
 
-from timm.models import is_model, list_models
+from timm.models import is_model, list_models, get_pretrained_cfg
 
 
 parser = argparse.ArgumentParser(description='Per-model process launcher')
@@ -98,16 +98,32 @@ def main():
     cmd, cmd_args = cmd_from_args(args)
 
     model_cfgs = []
-    model_names = []
     if args.model_list == 'all':
-        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
-        # so we filter out 21/22k and some other unusable heads. This will change in the future...
-        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
         model_names = list_models(
             pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
-            exclude_filters=exclude_model_filters
         )
         model_cfgs = [(n, None) for n in model_names]
+    elif args.model_list == 'all_in1k':
+        model_names = list_models(pretrained=True)
+        model_cfgs = []
+        for n in model_names:
+            pt_cfg = get_pretrained_cfg(n)
+            if getattr(pt_cfg, 'num_classes', 0) == 1000:
+                print(n, pt_cfg.num_classes)
+                model_cfgs.append((n, None))
+    elif args.model_list == 'all_res':
+        model_names = list_models()
+        model_names += [n.split('.')[0] for n in list_models(pretrained=True)]
+        model_cfgs = set()
+        for n in model_names:
+            pt_cfg = get_pretrained_cfg(n)
+            if pt_cfg is None:
+                print(f'Model {n} is missing pretrained cfg, skipping.')
+                continue
+            model_cfgs.add((n, pt_cfg.input_size[-1]))
+            if pt_cfg.test_input_size is not None:
+                model_cfgs.add((n, pt_cfg.test_input_size[-1]))
+        model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
     elif not is_model(args.model_list):
         # model name doesn't exist, try as wildcard filter
         model_names = list_models(args.model_list)
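
For context, a minimal sketch (not part of the commit; 'resnet50' is just an example model name) of the pretrained-config fields the new 'all_in1k' and 'all_res' modes consult:

```python
# Assumes a timm version that exports get_pretrained_cfg, as the import
# in this commit does.
from timm.models import get_pretrained_cfg

pt_cfg = get_pretrained_cfg('resnet50')
if pt_cfg is not None:
    print(pt_cfg.num_classes)            # 1000 -> selected by 'all_in1k'
    sizes = {pt_cfg.input_size[-1]}      # train resolution, e.g. 224
    if pt_cfg.test_input_size is not None:
        sizes.add(pt_cfg.test_input_size[-1])  # distinct eval resolution
    print(sorted(sizes))                 # 'all_res' schedules one run per size
```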
@@ -122,7 +138,8 @@ def main():
     results_file = args.results_file or './results.csv'
     results = []
     errors = []
-    print('Running script on these models: {}'.format(', '.join(model_names)))
+    model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
+    print(f"Running script on these models:\n {model_strings}")
     if not args.sort_key:
         if 'benchmark' in args.script:
             if any(['train' in a for a in args.script_args]):
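
For illustration, with hypothetical entries, the new summary print renders one name, extra-args pair per line:

```python
# Hypothetical model_cfgs entries mirroring the shapes built above:
# (name, None) for plain runs, (name, {'img-size': r}) for 'all_res'.
model_cfgs = [('resnet50', None), ('resnet50', {'img-size': 288})]
model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
print(f"Running script on these models:\n {model_strings}")
# Output:
# Running script on these models:
#  resnet50, None
# resnet50, {'img-size': 288}
```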
@@ -136,10 +153,14 @@ def main():
     print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
 
     try:
-        for m, _ in model_cfgs:
+        for m, ax in model_cfgs:
             if not m:
                 continue
             args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
+            if ax is not None:
+                extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
+                extra_args = [i for t in extra_args for i in t]
+                args_str += tuple(extra_args)
             try:
                 o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                 r = json.loads(o)
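
A standalone sketch of the flag expansion added above, using a made-up ax value:

```python
# With ax = {'img-size': 288}, the pair list is flattened into argv tokens
# appended after '--model <name>' on the child command line.
ax = {'img-size': 288}
extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]  # [('--img-size', '288')]
extra_args = [i for t in extra_args for i in t]           # flatten the pairs
print(extra_args)  # ['--img-size', '288']
```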
@@ -157,7 +178,11 @@ def main():
     if errors:
         print(f'{len(errors)} models had errors during run.')
         for e in errors:
-            print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
+            if 'model' in e:
+                print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
+            else:
+                print(e)
+
     results = list(filter(lambda x: 'error' not in x, results))
 
     no_sortkey = list(filter(lambda x: sort_key not in x, results))
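
A quick check of the hardened error report with hypothetical entries:

```python
# One entry from a failed model run, one without a 'model' key
# (e.g. a malformed result), which previously raised a KeyError.
errors = [{'model': 'resnet50', 'error': 'OOM'}, {'error': 'bad output'}]
for e in errors:
    if 'model' in e:
        print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
    else:
        print(e)
```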
