
Commit ff4d3e2

Merge branch 'main' into fast_load

2 parents 3222af6 + ea23107

File tree

12 files changed: +170 −71 lines

timm/data/dataset.py

Lines changed: 2 additions & 0 deletions

@@ -103,6 +103,7 @@ def __init__(
             transform=None,
             target_transform=None,
             max_steps=None,
+            **kwargs,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -121,6 +122,7 @@ def __init__(
                 input_key=input_key,
                 target_key=target_key,
                 max_steps=max_steps,
+                **kwargs,
             )
         else:
             self.reader = reader
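The new `**kwargs` pass-through lets reader-specific options ride along on the dataset constructor: anything not matching a named `IterableImageDataset` argument is collected and forwarded to `create_reader()`. A minimal sketch of the idea (the dataset name is illustrative; `trust_remote_code` is the reader option added later in this commit):

    from timm.data.dataset import IterableImageDataset

    # trust_remote_code is not a named IterableImageDataset argument; it is
    # swept up by **kwargs and forwarded to create_reader() for the reader.
    ds = IterableImageDataset(
        root=None,
        reader='hfids/my-org/my-dataset',  # hypothetical HF Hub dataset
        split='train',
        trust_remote_code=True,
    )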

timm/data/dataset_factory.py

Lines changed: 24 additions & 19 deletions

@@ -74,34 +74,37 @@ def create_dataset(
         seed: int = 42,
         repeats: int = 0,
         input_img_mode: str = 'RGB',
+        trust_remote_code: bool = False,
         **kwargs,
 ):
     """ Dataset factory method
 
     In parentheses after each arg are the type of dataset supported for each arg, one of:
-      * folder - default, timm folder (or tar) based ImageDataset
-      * torch - torchvision based datasets
+      * Folder - default, timm folder (or tar) based ImageDataset
+      * Torch - torchvision based datasets
       * HFDS - Hugging Face Datasets
+      * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
       * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
       * WDS - Webdataset
-      * all - any of the above
+      * All - any of the above
 
     Args:
-        name: dataset name, empty is okay for folder based datasets
-        root: root folder of dataset (all)
-        split: dataset split (all)
-        search_split: search for split specific child fold from root so one can specify
-            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
-        class_map: specify class -> index mapping via text file or dict (folder)
-        load_bytes: load data, return images as undecoded bytes (folder)
-        download: download dataset if not present and supported (HFDS, TFDS, torch)
-        is_training: create dataset in train mode, this is different from the split.
-            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS)
-        batch_size: batch size hint for (TFDS, WDS)
-        seed: seed for iterable datasets (TFDS, WDS)
-        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
-        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
-        **kwargs: other args to pass to dataset
+        name: Dataset name, empty is okay for folder based datasets
+        root: Root folder of dataset (All)
+        split: Dataset split (All)
+        search_split: Search for split specific child fold from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
+        class_map: Specify class -> index mapping via text file or dict (Folder)
+        load_bytes: Load data, return images as undecoded bytes (Folder)
+        download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
+        is_training: Create dataset in train mode, this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
+        batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
+        seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
+        repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (Folder, TFDS, WDS, HFDS, HFIDS)
+        trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
+        **kwargs: Other args to pass through to underlying Dataset and/or Reader classes
 
     Returns:
         Dataset object
@@ -162,6 +165,7 @@ def create_dataset(
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
+            trust_remote_code=trust_remote_code,
             **kwargs,
         )
     elif name.startswith('hfids/'):
@@ -177,7 +181,8 @@ def create_dataset(
             repeats=repeats,
             seed=seed,
             input_img_mode=input_img_mode,
-            **kwargs
+            trust_remote_code=trust_remote_code,
+            **kwargs,
         )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
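End to end, the new argument is a one-liner for users of the factory. A hedged usage sketch (the dataset name is illustrative):

    from timm.data import create_dataset

    # trust_remote_code only matters for Hugging Face backed datasets
    # (HFDS / HFIDS); other dataset types ignore it.
    ds = create_dataset(
        'hfds/my-org/my-dataset',  # hypothetical HF Hub dataset
        root='./data-cache',
        split='train',
        trust_remote_code=True,    # allow the dataset's loading script to run
    )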

timm/data/readers/reader_hfds.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def __init__(
         self.dataset = datasets.load_dataset(
             name,   # 'name' maps to path arg in hf datasets
             split=split,
-            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path if root set
             trust_remote_code=trust_remote_code
         )
         # leave decode for caller, plus we want easy access to original path names...
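The comment change is clarification only; behavior is unchanged. For reference, the underlying Hugging Face call looks roughly like this standalone sketch (dataset name and path illustrative):

    import datasets

    # With a root set, the HF datasets cache lands under that path instead of
    # the default hidden ~/.cache/huggingface location.
    ds = datasets.load_dataset(
        'beans',                  # small public example dataset
        split='train',
        cache_dir='./hf-cache',   # corresponds to ReaderHfds' self.root when set
        trust_remote_code=False,
    )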

timm/data/readers/reader_hfids.py

Lines changed: 6 additions & 1 deletion

@@ -44,6 +44,7 @@ def __init__(
             target_img_mode: str = '',
             shuffle_size: Optional[int] = None,
             num_samples: Optional[int] = None,
+            trust_remote_code: bool = False
     ):
         super().__init__()
         self.root = root
@@ -60,7 +61,11 @@ def __init__(
         self.target_key = target_key
         self.target_img_mode = target_img_mode
 
-        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        self.builder = datasets.load_dataset_builder(
+            name,
+            cache_dir=root,
+            trust_remote_code=trust_remote_code,
+        )
         if download:
            self.builder.download_and_prepare()
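The builder-based flow used here separates metadata resolution from the (optional) download. Sketched standalone against the Hugging Face `datasets` API (dataset name and path illustrative):

    import datasets

    # load_dataset_builder resolves features/splits without downloading data;
    # trust_remote_code gates execution of a dataset's custom loading script.
    builder = datasets.load_dataset_builder(
        'beans',
        cache_dir='./hf-cache',
        trust_remote_code=False,
    )
    builder.download_and_prepare()          # only when download is requested
    ds = builder.as_dataset(split='train')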

timm/models/_builder.py

Lines changed: 36 additions & 26 deletions

@@ -2,6 +2,7 @@
 import logging
 import os
 from copy import deepcopy
+from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple
 from contextlib import nullcontext
 
@@ -92,6 +93,7 @@ def load_custom_pretrained(
         model: nn.Module,
         pretrained_cfg: Optional[Dict] = None,
         load_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     r"""Loads a custom (read non .pth) weight file
 
@@ -104,9 +106,10 @@ def load_custom_pretrained(
 
     Args:
         model: The instantiated model to load weights into
-        pretrained_cfg (dict): Default pretrained model cfg
+        pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
-            'laod_pretrained' on the model will be called if it exists
+            'load_pretrained' on the model will be called if it exists
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -124,6 +127,7 @@ def load_custom_pretrained(
             pretrained_loc,
             check_hash=_CHECK_HASH,
             progress=_DOWNLOAD_PROGRESS,
+            cache_dir=cache_dir,
         )
 
     if load_fn is not None:
@@ -141,17 +145,18 @@ def load_pretrained(
         in_chans: int = 3,
         filter_fn: Optional[Callable] = None,
         strict: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     """ Load pretrained checkpoint
 
     Args:
-        model (nn.Module) : PyTorch model module
-        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
-        num_classes (int): num_classes for target model
-        in_chans (int): in_chans for target model
-        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
-        strict (bool): strict load of checkpoint
-
+        model: PyTorch module
+        pretrained_cfg: Configuration for pretrained weights / target dataset
+        num_classes: Number of classes for target model. Will adapt pretrained if different.
+        in_chans: Number of input chans for target model. Will adapt pretrained if different.
+        filter_fn: state_dict filter fn for load (takes state_dict, model as args)
+        strict: Strict load of checkpoint
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -175,6 +180,7 @@ def load_pretrained(
             pretrained_loc,
             progress=_DOWNLOAD_PROGRESS,
             check_hash=_CHECK_HASH,
+            cache_dir=cache_dir,
         )
         model.load_pretrained(pretrained_loc)
         return
@@ -186,25 +192,27 @@ def load_pretrained(
                 progress=_DOWNLOAD_PROGRESS,
                 check_hash=_CHECK_HASH,
                 weights_only=True,
+                model_dir=cache_dir,
             )
         except TypeError:
             state_dict = load_state_dict_from_url(
                 pretrained_loc,
                 map_location='cpu',
                 progress=_DOWNLOAD_PROGRESS,
                 check_hash=_CHECK_HASH,
+                model_dir=cache_dir,
             )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
             custom_load = pretrained_cfg.get('custom_load', False)
             if isinstance(custom_load, str) and custom_load == 'hf':
-                load_custom_from_hf(*pretrained_loc, model)
+                load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
                 return
             else:
-                state_dict = load_state_dict_from_hf(*pretrained_loc)
+                state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
@@ -321,8 +329,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
 
 def resolve_pretrained_cfg(
         variant: str,
-        pretrained_cfg=None,
-        pretrained_cfg_overlay=None,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
     model_with_tag = variant
     pretrained_tag = None
@@ -364,6 +372,7 @@ def build_model_with_cfg(
         feature_cfg: Optional[Dict] = None,
         pretrained_strict: bool = True,
         pretrained_filter_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs,
 ):
@@ -376,16 +385,18 @@ def build_model_with_cfg(
     * pruning config / model adaptation
 
     Args:
-        model_cls: model class
-        variant: model variant name
-        pretrained: load pretrained weights
-        pretrained_cfg: model's pretrained weight/task config
-        model_cfg: model's architecture config
-        feature_cfg: feature extraction adapter config
-        pretrained_strict: load pretrained weights strictly
-        pretrained_filter_fn: filter callable for pretrained weights
-        kwargs_filter: kwargs to filter before passing to model
-        **kwargs: model args passed through to model __init__
+        model_cls: Model class
+        variant: Model variant name
+        pretrained: Load the pretrained weights
+        pretrained_cfg: Model's pretrained weight/task config
+        pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
+        model_cfg: Model's architecture config
+        feature_cfg: Feature extraction adapter config
+        pretrained_strict: Load pretrained weights strictly
+        pretrained_filter_fn: Filter callable for pretrained weights
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
+        kwargs_filter: Kwargs keys to filter (remove) before passing to model
+        **kwargs: Model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
     features = False
@@ -397,8 +408,6 @@ def build_model_with_cfg(
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay
     )
-
-    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()
 
     _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
@@ -437,6 +446,7 @@ def build_model_with_cfg(
             in_chans=kwargs.get('in_chans', 3),
             filter_fn=pretrained_filter_fn,
             strict=pretrained_strict,
+            cache_dir=cache_dir,
         )
 
     # Wrap the model in a feature extraction module if enabled
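The `model_dir=cache_dir` plus `except TypeError` pattern above exists because older torch releases lack the `weights_only` argument on `load_state_dict_from_url`. A minimal standalone sketch of the same fallback (URL and paths hypothetical):

    from torch.hub import load_state_dict_from_url

    url = 'https://example.com/checkpoint.pth'  # hypothetical checkpoint URL
    cache_dir = '/data/my-models'

    try:
        # Newer torch: weights_only=True restricts unpickling to tensor data.
        state_dict = load_state_dict_from_url(
            url, map_location='cpu', model_dir=cache_dir, weights_only=True)
    except TypeError:
        # Older torch without weights_only support.
        state_dict = load_state_dict_from_url(
            url, map_location='cpu', model_dir=cache_dir)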

timm/models/_factory.py

Lines changed: 14 additions & 4 deletions

@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 from urllib.parse import urlsplit
 
@@ -40,7 +41,8 @@ def create_model(
         pretrained: bool = False,
         pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
         pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
-        checkpoint_path: str = '',
+        checkpoint_path: Optional[Union[str, Path]] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         scriptable: Optional[bool] = None,
         exportable: Optional[bool] = None,
         no_jit: Optional[bool] = None,
@@ -50,17 +52,17 @@ def create_model(
 
     Lookup model's entrypoint function and pass relevant args to create a new model.
 
-    <Tip>
+    Tip:
         **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
         and then the model class __init__(). kwargs values set to None are pruned before passing.
-    </Tip>
 
     Args:
         model_name: Name of model to instantiate.
         pretrained: If set to `True`, load pretrained ImageNet-1k weights.
         pretrained_cfg: Pass in an external pretrained_cfg for model.
         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
         scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
@@ -87,6 +89,10 @@ def create_model(
     >>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
     >>> model.num_classes
     10
+
+    >>> # Create a Dinov2 small model with pretrained weights and save weights in a custom directory.
+    >>> model = create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, cache_dir="/data/my-models")
+    >>> # Data will be stored at `/data/my-models/models--timm--vit_small_patch14_dinov2.lvd142m/`
     ```
     """
     # Parameters that aren't supported by all models or are intended to only override model defaults if set
@@ -99,7 +105,10 @@ def create_model(
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
         # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+            model_name,
+            cache_dir=cache_dir,
+        )
         if model_args:
             for k, v in model_args.items():
                 kwargs.setdefault(k, v)
@@ -118,6 +127,7 @@ def create_model(
         pretrained=pretrained,
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay,
+        cache_dir=cache_dir,
         **kwargs,
     )
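Since the hf-hub path now forwards `cache_dir` into `load_model_config_from_hf` as well, Hub-sourced configs and weights share one cache location. A hedged example:

    import timm

    # Both the config fetched from the Hub and the checkpoint are cached under
    # cache_dir rather than the default ~/.cache locations.
    model = timm.create_model(
        'hf-hub:timm/resnet18.a1_in1k',  # a timm model hosted on the HF Hub
        pretrained=True,
        cache_dir='/data/my-models',
    )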

timm/models/_helpers.py

Lines changed: 0 additions & 1 deletion

@@ -4,7 +4,6 @@
 """
 import logging
 import os
-from collections import OrderedDict
 from typing import Any, Callable, Dict, Optional, Union
 
 import torch
