Skip to content

Commit a5b01ec

Browse files
BenjaminBossanrwightman
authored andcommitted
Add type annotations to _registry.py
Description Add type annotations to _registry.py so that they will pass mypy --strict. Comment I was reading the code and felt that this module would be easier to understand with type annotations. Therefore, I went ahead and added the annotations. The idea with this PR is to start small to see if we can align on _how_ to annotate types. I've seen people in the past disagree on how strictly to annotate the code base, so before spending too much time on this, I wanted to check if you agree, Ross. Most of the added types should be straightforward. Some notes on the non-trivial changes: - I made no assumption about the fn passed to register_model, but maybe the type could be stricter. Are all models nn.Modules? - If I'm not mistaken, the type hint for get_arch_name was incorrect - I had to add a # type: ignore to model.__all__ = ... - I made some minor code changes to list_models to facilitate the typing. I think the changes should not affect the logic of the function. - I removed list from list(sorted(...)) because sorted returns always a list.
1 parent c9406ce commit a5b01ec

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

timm/models/_pretrained.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def default_with_tag(self):
9393
return tag, self.cfgs[tag]
9494

9595

96-
def split_model_name_tag(model_name: str, no_tag=''):
96+
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
9797
model_name, *tag_list = model_name.split('.', 1)
9898
tag = tag_list[0] if tag_list else no_tag
9999
return model_name, tag

timm/models/_registry.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,28 @@
88
from collections import defaultdict, deque
99
from copy import deepcopy
1010
from dataclasses import replace
11-
from typing import List, Optional, Union, Tuple
11+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
1212

1313
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
1414

1515
__all__ = [
1616
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
1717
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
1818

19-
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
20-
_model_to_module = {} # mapping of model names to module names
21-
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
22-
_model_has_pretrained = set() # set of model names that have pretrained weight url present
23-
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
24-
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
25-
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
19+
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
20+
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
21+
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
22+
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
23+
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
24+
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
25+
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
2626

2727

28-
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
28+
def get_arch_name(model_name: str) -> str:
2929
return split_model_name_tag(model_name)[0]
3030

3131

32-
def register_model(fn):
32+
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
3333
# lookup containing module
3434
mod = sys.modules[fn.__module__]
3535
module_name_split = fn.__module__.split('.')
@@ -40,7 +40,7 @@ def register_model(fn):
4040
if hasattr(mod, '__all__'):
4141
mod.__all__.append(model_name)
4242
else:
43-
mod.__all__ = [model_name]
43+
mod.__all__ = [model_name] # type: ignore
4444

4545
# add entries to registry dict/sets
4646
_model_entrypoints[model_name] = fn
@@ -87,28 +87,33 @@ def register_model(fn):
8787
return fn
8888

8989

90-
def _natural_key(string_):
90+
def _natural_key(string_: str) -> List[Union[int, str]]:
91+
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
9192
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
9293

9394

9495
def list_models(
9596
filter: Union[str, List[str]] = '',
9697
module: str = '',
97-
pretrained=False,
98-
exclude_filters: str = '',
98+
pretrained: bool = False,
99+
exclude_filters: Union[str, List[str]] = '',
99100
name_matches_cfg: bool = False,
100101
include_tags: Optional[bool] = None,
101-
):
102+
) -> List[str]:
102103
""" Return list of available model names, sorted alphabetically
103104
104105
Args:
105-
filter (str) - Wildcard filter string that works with fnmatch
106-
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
107-
pretrained (bool) - Include only models with valid pretrained weights if True
108-
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
109-
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
110-
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
106+
filter - Wildcard filter string that works with fnmatch
107+
module - Limit model selection to a specific submodule (ie 'vision_transformer')
108+
pretrained - Include only models with valid pretrained weights if True
109+
exclude_filters - Wildcard filters to exclude models after including them with filter
110+
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
111+
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
111112
set to True when pretrained=True else False (default: None)
113+
114+
Returns:
115+
models - The sorted list of models
116+
112117
Example:
113118
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
114119
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@@ -118,7 +123,7 @@ def list_models(
118123
include_tags = pretrained
119124

120125
if module:
121-
all_models = list(_module_to_models[module])
126+
all_models: Iterable[str] = list(_module_to_models[module])
122127
else:
123128
all_models = _model_entrypoints.keys()
124129

@@ -130,36 +135,36 @@ def list_models(
130135
all_models = models_with_tags
131136

132137
if filter:
133-
models = []
138+
models: Set[str] = set()
134139
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
135140
for f in include_filters:
136141
include_models = fnmatch.filter(all_models, f) # include these models
137142
if len(include_models):
138-
models = set(models).union(include_models)
143+
models = models.union(include_models)
139144
else:
140-
models = all_models
145+
models = set(all_models)
141146

142147
if exclude_filters:
143148
if not isinstance(exclude_filters, (tuple, list)):
144149
exclude_filters = [exclude_filters]
145150
for xf in exclude_filters:
146151
exclude_models = fnmatch.filter(models, xf) # exclude these models
147152
if len(exclude_models):
148-
models = set(models).difference(exclude_models)
153+
models = models.difference(exclude_models)
149154

150155
if pretrained:
151156
models = _model_has_pretrained.intersection(models)
152157

153158
if name_matches_cfg:
154159
models = set(_model_pretrained_cfgs).intersection(models)
155160

156-
return list(sorted(models, key=_natural_key))
161+
return sorted(models, key=_natural_key)
157162

158163

159164
def list_pretrained(
160165
filter: Union[str, List[str]] = '',
161166
exclude_filters: str = '',
162-
):
167+
) -> List[str]:
163168
return list_models(
164169
filter=filter,
165170
pretrained=True,
@@ -168,14 +173,14 @@ def list_pretrained(
168173
)
169174

170175

171-
def is_model(model_name):
176+
def is_model(model_name: str) -> bool:
172177
""" Check if a model name exists
173178
"""
174179
arch_name = get_arch_name(model_name)
175180
return arch_name in _model_entrypoints
176181

177182

178-
def model_entrypoint(model_name, module_filter: Optional[str] = None):
183+
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
179184
"""Fetch a model entrypoint for specified model name
180185
"""
181186
arch_name = get_arch_name(model_name)
@@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
184189
return _model_entrypoints[arch_name]
185190

186191

187-
def list_modules():
192+
def list_modules() -> List[str]:
188193
""" Return list of module names that contain models / model entrypoints
189194
"""
190195
modules = _module_to_models.keys()
191-
return list(sorted(modules))
196+
return sorted(modules)
192197

193198

194-
def is_model_in_modules(model_name, module_names):
199+
def is_model_in_modules(
200+
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
201+
) -> bool:
195202
"""Check if a model exists within a subset of modules
203+
196204
Args:
197-
model_name (str) - name of model to check
198-
module_names (tuple, list, set) - names of modules to search in
205+
model_name - name of model to check
206+
module_names - names of modules to search in
199207
"""
200208
arch_name = get_arch_name(model_name)
201209
assert isinstance(module_names, (tuple, list, set))
202210
return any(arch_name in _module_to_models[n] for n in module_names)
203211

204212

205-
def is_model_pretrained(model_name):
213+
def is_model_pretrained(model_name: str) -> bool:
206214
return model_name in _model_has_pretrained
207215

208216

209-
def get_pretrained_cfg(model_name, allow_unregistered=True):
217+
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
210218
if model_name in _model_pretrained_cfgs:
211219
return deepcopy(_model_pretrained_cfgs[model_name])
212220
arch_name, tag = split_model_name_tag(model_name)
@@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
219227
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
220228

221229

222-
def get_pretrained_cfg_value(model_name, cfg_key):
230+
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
223231
""" Get a specific model default_cfg value by key. None if key doesn't exist.
224232
"""
225233
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)

0 commit comments

Comments
 (0)