@@ -21,7 +21,7 @@
 from typing import Callable, List, Tuple, Union
 
 
-from timm.models import is_model, list_models
+from timm.models import is_model, list_models, get_pretrained_cfg
 
 
 parser = argparse.ArgumentParser(description='Per-model process launcher')
@@ -98,16 +98,32 @@ def main():
     cmd, cmd_args = cmd_from_args(args)
 
     model_cfgs = []
-    model_names = []
     if args.model_list == 'all':
-        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
-        # so we filter out 21/22k and some other unusable heads. This will change in the future...
-        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
         model_names = list_models(
             pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
-            exclude_filters=exclude_model_filters
         )
         model_cfgs = [(n, None) for n in model_names]
+    elif args.model_list == 'all_in1k':
+        model_names = list_models(pretrained=True)
+        model_cfgs = []
+        for n in model_names:
+            pt_cfg = get_pretrained_cfg(n)
+            if getattr(pt_cfg, 'num_classes', 0) == 1000:
+                print(n, pt_cfg.num_classes)
+                model_cfgs.append((n, None))
+    elif args.model_list == 'all_res':
+        model_names = list_models()
+        model_names += [n.split('.')[0] for n in list_models(pretrained=True)]
+        model_cfgs = set()
+        for n in model_names:
+            pt_cfg = get_pretrained_cfg(n)
+            if pt_cfg is None:
+                print(f'Model {n} is missing pretrained cfg, skipping.')
+                continue
+            model_cfgs.add((n, pt_cfg.input_size[-1]))
+            if pt_cfg.test_input_size is not None:
+                model_cfgs.add((n, pt_cfg.test_input_size[-1]))
+        model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
     elif not is_model(args.model_list):
         # model name doesn't exist, try as wildcard filter
         model_names = list_models(args.model_list)
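
For reference, a minimal standalone sketch of what the new pretrained-cfg lookups in the 'all_res' branch compute (assuming a timm version that exports get_pretrained_cfg; the variable names below are illustrative, not from the commit):

    from timm.models import list_models, get_pretrained_cfg

    variants = set()
    for name in list_models(pretrained=True):
        cfg = get_pretrained_cfg(name)
        if cfg is None:
            continue  # no pretrained cfg registered under this name
        # input_size is (C, H, W); only the trailing spatial dim is kept, as in the diff
        variants.add((name, cfg.input_size[-1]))
        if cfg.test_input_size is not None:
            # a second entry when the checkpoint has a distinct test resolution
            variants.add((name, cfg.test_input_size[-1]))
    print(sorted(variants)[:5])

Each model thus yields one entry per distinct resolution, and the final list comprehension turns each resolution into a per-run {'img-size': r} argument dict for the launched process.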
@@ -122,7 +138,8 @@ def main():
     results_file = args.results_file or './results.csv'
     results = []
     errors = []
-    print('Running script on these models: {}'.format(', '.join(model_names)))
+    model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
+    print(f"Running script on these models:\n {model_strings}")
     if not args.sort_key:
         if 'benchmark' in args.script:
             if any(['train' in a for a in args.script_args]):
@@ -136,10 +153,14 @@ def main():
     print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
 
     try:
-        for m, _ in model_cfgs:
+        for m, ax in model_cfgs:
             if not m:
                 continue
             args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
+            if ax is not None:
+                extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
+                extra_args = [i for t in extra_args for i in t]
+                args_str += tuple(extra_args)
             try:
                 o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                 r = json.loads(o)
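
To make the wiring concrete, the ax dict attached to each model config is flattened into flag/value pairs before the subprocess call; a small sketch with a hypothetical value:

    ax = {'img-size': 288}
    extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]  # [('--img-size', '288')]
    extra_args = [i for t in extra_args for i in t]           # ['--img-size', '288']
    # appended to args_str, so the child process receives: --model <name> --img-size 288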
@@ -157,7 +178,11 @@ def main():
     if errors:
         print(f'{len(errors)} models had errors during run.')
         for e in errors:
-            print(f"\t{e['model']} ({e.get('error', 'Unknown')})")
+            if 'model' in e:
+                print(f"\t{e['model']} ({e.get('error', 'Unknown')})")
+            else:
+                print(e)
+
     results = list(filter(lambda x: 'error' not in x, results))
 
     no_sortkey = list(filter(lambda x: sort_key not in x, results))