8
8
from collections import defaultdict , deque
9
9
from copy import deepcopy
10
10
from dataclasses import replace
11
- from typing import List , Optional , Union , Tuple
11
+ from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Sequence , Union , Tuple
12
12
13
13
from ._pretrained import PretrainedCfg , DefaultCfg , split_model_name_tag
14
14
15
15
__all__ = [
16
16
'list_models' , 'list_pretrained' , 'is_model' , 'model_entrypoint' , 'list_modules' , 'is_model_in_modules' ,
17
17
'get_pretrained_cfg_value' , 'is_model_pretrained' , 'get_arch_name' ]
18
18
19
- _module_to_models = defaultdict (set ) # dict of sets to check membership of model in module
20
- _model_to_module = {} # mapping of model names to module names
21
- _model_entrypoints = {} # mapping of model names to architecture entrypoint fns
22
- _model_has_pretrained = set () # set of model names that have pretrained weight url present
23
- _model_default_cfgs = dict () # central repo for model arch -> default cfg objects
24
- _model_pretrained_cfgs = dict () # central repo for model arch.tag -> pretrained cfgs
25
- _model_with_tags = defaultdict (list ) # shortcut to map each model arch to all model + tag names
19
+ _module_to_models : Dict [ str , Set [ str ]] = defaultdict (set ) # dict of sets to check membership of model in module
20
+ _model_to_module : Dict [ str , str ] = {} # mapping of model names to module names
21
+ _model_entrypoints : Dict [ str , Callable [..., Any ]] = {} # mapping of model names to architecture entrypoint fns
22
+ _model_has_pretrained : Set [ str ] = set () # set of model names that have pretrained weight url present
23
+ _model_default_cfgs : Dict [ str , PretrainedCfg ] = {} # central repo for model arch -> default cfg objects
24
+ _model_pretrained_cfgs : Dict [ str , PretrainedCfg ] = {} # central repo for model arch.tag -> pretrained cfgs
25
+ _model_with_tags : Dict [ str , List [ str ]] = defaultdict (list ) # shortcut to map each model arch to all model + tag names
26
26
27
27
28
- def get_arch_name (model_name : str ) -> Tuple [ str , Optional [ str ]] :
28
+ def get_arch_name (model_name : str ) -> str :
29
29
return split_model_name_tag (model_name )[0 ]
30
30
31
31
32
- def register_model (fn ) :
32
+ def register_model (fn : Callable [..., Any ]) -> Callable [..., Any ] :
33
33
# lookup containing module
34
34
mod = sys .modules [fn .__module__ ]
35
35
module_name_split = fn .__module__ .split ('.' )
@@ -40,7 +40,7 @@ def register_model(fn):
40
40
if hasattr (mod , '__all__' ):
41
41
mod .__all__ .append (model_name )
42
42
else :
43
- mod .__all__ = [model_name ]
43
+ mod .__all__ = [model_name ] # type: ignore
44
44
45
45
# add entries to registry dict/sets
46
46
_model_entrypoints [model_name ] = fn
@@ -87,28 +87,33 @@ def register_model(fn):
87
87
return fn
88
88
89
89
90
- def _natural_key (string_ ):
90
+ def _natural_key (string_ : str ) -> List [Union [int , str ]]:
91
+ """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
91
92
return [int (s ) if s .isdigit () else s for s in re .split (r'(\d+)' , string_ .lower ())]
92
93
93
94
94
95
def list_models (
95
96
filter : Union [str , List [str ]] = '' ,
96
97
module : str = '' ,
97
- pretrained = False ,
98
- exclude_filters : str = '' ,
98
+ pretrained : bool = False ,
99
+ exclude_filters : Union [ str , List [ str ]] = '' ,
99
100
name_matches_cfg : bool = False ,
100
101
include_tags : Optional [bool ] = None ,
101
- ):
102
+ ) -> List [ str ] :
102
103
""" Return list of available model names, sorted alphabetically
103
104
104
105
Args:
105
- filter (str) - Wildcard filter string that works with fnmatch
106
- module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
107
- pretrained (bool) - Include only models with valid pretrained weights if True
108
- exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
109
- name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
110
- include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
106
+ filter - Wildcard filter string that works with fnmatch
107
+ module - Limit model selection to a specific submodule (ie 'vision_transformer')
108
+ pretrained - Include only models with valid pretrained weights if True
109
+ exclude_filters - Wildcard filters to exclude models after including them with filter
110
+ name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
111
+ include_tags - Include pretrained tags in model names (model.tag). If None, defaults
111
112
set to True when pretrained=True else False (default: None)
113
+
114
+ Returns:
115
+ models - The sorted list of models
116
+
112
117
Example:
113
118
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
114
119
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@@ -118,7 +123,7 @@ def list_models(
118
123
include_tags = pretrained
119
124
120
125
if module :
121
- all_models = list (_module_to_models [module ])
126
+ all_models : Iterable [ str ] = list (_module_to_models [module ])
122
127
else :
123
128
all_models = _model_entrypoints .keys ()
124
129
@@ -130,36 +135,36 @@ def list_models(
130
135
all_models = models_with_tags
131
136
132
137
if filter :
133
- models = []
138
+ models : Set [ str ] = set ()
134
139
include_filters = filter if isinstance (filter , (tuple , list )) else [filter ]
135
140
for f in include_filters :
136
141
include_models = fnmatch .filter (all_models , f ) # include these models
137
142
if len (include_models ):
138
- models = set ( models ) .union (include_models )
143
+ models = models .union (include_models )
139
144
else :
140
- models = all_models
145
+ models = set ( all_models )
141
146
142
147
if exclude_filters :
143
148
if not isinstance (exclude_filters , (tuple , list )):
144
149
exclude_filters = [exclude_filters ]
145
150
for xf in exclude_filters :
146
151
exclude_models = fnmatch .filter (models , xf ) # exclude these models
147
152
if len (exclude_models ):
148
- models = set ( models ) .difference (exclude_models )
153
+ models = models .difference (exclude_models )
149
154
150
155
if pretrained :
151
156
models = _model_has_pretrained .intersection (models )
152
157
153
158
if name_matches_cfg :
154
159
models = set (_model_pretrained_cfgs ).intersection (models )
155
160
156
- return list ( sorted (models , key = _natural_key ) )
161
+ return sorted (models , key = _natural_key )
157
162
158
163
159
164
def list_pretrained (
160
165
filter : Union [str , List [str ]] = '' ,
161
166
exclude_filters : str = '' ,
162
- ):
167
+ ) -> List [ str ] :
163
168
return list_models (
164
169
filter = filter ,
165
170
pretrained = True ,
@@ -168,14 +173,14 @@ def list_pretrained(
168
173
)
169
174
170
175
171
- def is_model (model_name ) :
176
+ def is_model (model_name : str ) -> bool :
172
177
""" Check if a model name exists
173
178
"""
174
179
arch_name = get_arch_name (model_name )
175
180
return arch_name in _model_entrypoints
176
181
177
182
178
- def model_entrypoint (model_name , module_filter : Optional [str ] = None ):
183
+ def model_entrypoint (model_name : str , module_filter : Optional [str ] = None ) -> Callable [..., Any ] :
179
184
"""Fetch a model entrypoint for specified model name
180
185
"""
181
186
arch_name = get_arch_name (model_name )
@@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
184
189
return _model_entrypoints [arch_name ]
185
190
186
191
187
- def list_modules ():
192
+ def list_modules () -> List [ str ] :
188
193
""" Return list of module names that contain models / model entrypoints
189
194
"""
190
195
modules = _module_to_models .keys ()
191
- return list ( sorted (modules ) )
196
+ return sorted (modules )
192
197
193
198
194
- def is_model_in_modules (model_name , module_names ):
199
+ def is_model_in_modules (
200
+ model_name : str , module_names : Union [Tuple [str , ...], List [str ], Set [str ]]
201
+ ) -> bool :
195
202
"""Check if a model exists within a subset of modules
203
+
196
204
Args:
197
- model_name (str) - name of model to check
198
- module_names (tuple, list, set) - names of modules to search in
205
+ model_name - name of model to check
206
+ module_names - names of modules to search in
199
207
"""
200
208
arch_name = get_arch_name (model_name )
201
209
assert isinstance (module_names , (tuple , list , set ))
202
210
return any (arch_name in _module_to_models [n ] for n in module_names )
203
211
204
212
205
- def is_model_pretrained (model_name ) :
213
+ def is_model_pretrained (model_name : str ) -> bool :
206
214
return model_name in _model_has_pretrained
207
215
208
216
209
- def get_pretrained_cfg (model_name , allow_unregistered = True ):
217
+ def get_pretrained_cfg (model_name : str , allow_unregistered : bool = True ) -> Optional [ PretrainedCfg ] :
210
218
if model_name in _model_pretrained_cfgs :
211
219
return deepcopy (_model_pretrained_cfgs [model_name ])
212
220
arch_name , tag = split_model_name_tag (model_name )
@@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
219
227
raise RuntimeError (f'Model architecture ({ arch_name } ) has no pretrained cfg registered.' )
220
228
221
229
222
- def get_pretrained_cfg_value (model_name , cfg_key ) :
230
+ def get_pretrained_cfg_value (model_name : str , cfg_key : str ) -> Optional [ Any ] :
223
231
""" Get a specific model default_cfg value by key. None if key doesn't exist.
224
232
"""
225
233
cfg = get_pretrained_cfg (model_name , allow_unregistered = False )
0 commit comments