 import logging
 import os
 from copy import deepcopy
+from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 from contextlib import nullcontext
 
@@ -92,6 +93,7 @@ def load_custom_pretrained(
         model: nn.Module,
         pretrained_cfg: Optional[Dict] = None,
         load_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     r"""Loads a custom (read non .pth) weight file
@@ -104,9 +106,10 @@ def load_custom_pretrained(
 
     Args:
         model: The instantiated model to load weights into
-        pretrained_cfg (dict): Default pretrained model cfg
+        pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
-            'laod_pretrained' on the model will be called if it exists
+            'load_pretrained' on the model will be called if it exists
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -124,6 +127,7 @@ def load_custom_pretrained(
             pretrained_loc,
             check_hash=_CHECK_HASH,
             progress=_DOWNLOAD_PROGRESS,
+            cache_dir=cache_dir,
         )
 
     if load_fn is not None:
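
As a usage sketch of what this plumbing enables (the variant name and target directory are illustrative, not from the commit): load_custom_pretrained can now place the downloaded custom-format checkpoint in a caller-chosen directory instead of the default torch hub cache.

    from pathlib import Path

    import timm
    from timm.models._builder import load_custom_pretrained

    # Hypothetical variant whose pretrained_cfg declares a custom (non .pth) loader.
    model = timm.create_model('my_custom_variant', pretrained=False)
    # Cache the checkpoint under ./weights rather than the default torch hub
    # cache dir, then call the model's own load hook.
    load_custom_pretrained(model, cache_dir=Path('./weights'))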
@@ -141,17 +145,18 @@ def load_pretrained(
         in_chans: int = 3,
         filter_fn: Optional[Callable] = None,
         strict: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     """ Load pretrained checkpoint
 
     Args:
-        model (nn.Module): PyTorch model module
-        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
-        num_classes (int): num_classes for target model
-        in_chans (int): in_chans for target model
-        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
-        strict (bool): strict load of checkpoint
-
+        model: PyTorch module
+        pretrained_cfg: Configuration for pretrained weights / target dataset
+        num_classes: Number of classes for target model. Will adapt pretrained if different.
+        in_chans: Number of input chans for target model. Will adapt pretrained if different.
+        filter_fn: state_dict filter fn for load (takes state_dict, model as args)
+        strict: Strict load of checkpoint
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -175,6 +180,7 @@ def load_pretrained(
                 pretrained_loc,
                 progress=_DOWNLOAD_PROGRESS,
                 check_hash=_CHECK_HASH,
+                cache_dir=cache_dir,
             )
             model.load_pretrained(pretrained_loc)
             return
@@ -186,25 +192,27 @@ def load_pretrained(
                     progress=_DOWNLOAD_PROGRESS,
                     check_hash=_CHECK_HASH,
                     weights_only=True,
+                    model_dir=cache_dir,
                 )
             except TypeError:
                 state_dict = load_state_dict_from_url(
                     pretrained_loc,
                     map_location='cpu',
                     progress=_DOWNLOAD_PROGRESS,
                     check_hash=_CHECK_HASH,
+                    model_dir=cache_dir,
                 )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
             custom_load = pretrained_cfg.get('custom_load', False)
             if isinstance(custom_load, str) and custom_load == 'hf':
-                load_custom_from_hf(*pretrained_loc, model)
+                load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
                 return
             else:
-                state_dict = load_state_dict_from_hf(*pretrained_loc)
+                state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
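
The try/except TypeError pair above exists because torch.hub.load_state_dict_from_url only accepts weights_only on newer PyTorch releases, so the call is retried without it on older ones. Note the rename at this boundary: torch.hub calls its cache-dir parameter model_dir, hence cache_dir is forwarded as model_dir=cache_dir. A standalone sketch of the same compatibility pattern (the URL and directory are placeholders):

    from torch.hub import load_state_dict_from_url

    def fetch_state_dict(url, cache_dir=None):
        # model_dir=None falls back to the default torch hub checkpoint dir.
        common = dict(map_location='cpu', progress=True, model_dir=cache_dir)
        try:
            # Newer PyTorch: restrict unpickling to tensors and plain containers.
            return load_state_dict_from_url(url, weights_only=True, **common)
        except TypeError:
            # Older PyTorch: weights_only is not a recognized kwarg.
            return load_state_dict_from_url(url, **common)

    state_dict = fetch_state_dict('https://example.com/checkpoint.pth', cache_dir='./checkpoints')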
@@ -321,8 +329,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
 
 def resolve_pretrained_cfg(
         variant: str,
-        pretrained_cfg=None,
-        pretrained_cfg_overlay=None,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
     model_with_tag = variant
     pretrained_tag = None
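
The new annotations document what the function already accepted: pretrained_cfg may be a tag/config name or a full dict, and pretrained_cfg_overlay is a partial dict merged over the resolved cfg. A hedged sketch of the call shape, assuming resolve_pretrained_cfg is importable from timm.models and crop_pct is the field to override (both are assumptions, not from this commit):

    from timm.models import resolve_pretrained_cfg

    # Resolve the default pretrained cfg for a variant, overriding one field.
    cfg = resolve_pretrained_cfg(
        'resnet50',
        pretrained_cfg_overlay=dict(crop_pct=1.0),  # assumed overlay key
    )
    print(cfg.to_dict())  # PretrainedCfg -> plain dict, as build_model_with_cfg does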
@@ -364,6 +372,7 @@ def build_model_with_cfg(
         feature_cfg: Optional[Dict] = None,
         pretrained_strict: bool = True,
         pretrained_filter_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs,
 ):
@@ -376,16 +385,18 @@ def build_model_with_cfg(
     * pruning config / model adaptation
 
     Args:
-        model_cls: model class
-        variant: model variant name
-        pretrained: load pretrained weights
-        pretrained_cfg: model's pretrained weight/task config
-        model_cfg: model's architecture config
-        feature_cfg: feature extraction adapter config
-        pretrained_strict: load pretrained weights strictly
-        pretrained_filter_fn: filter callable for pretrained weights
-        kwargs_filter: kwargs to filter before passing to model
-        **kwargs: model args passed through to model __init__
+        model_cls: Model class
+        variant: Model variant name
+        pretrained: Load the pretrained weights
+        pretrained_cfg: Model's pretrained weight/task config
+        pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
+        model_cfg: Model's architecture config
+        feature_cfg: Feature extraction adapter config
+        pretrained_strict: Load pretrained weights strictly
+        pretrained_filter_fn: Filter callable for pretrained weights
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
+        kwargs_filter: Kwargs keys to filter (remove) before passing to model
+        **kwargs: Model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
     features = False
@@ -397,8 +408,6 @@ def build_model_with_cfg(
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay
     )
-
-    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()
 
     _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
@@ -437,6 +446,7 @@ def build_model_with_cfg(
             in_chans=kwargs.get('in_chans', 3),
             filter_fn=pretrained_filter_fn,
             strict=pretrained_strict,
+            cache_dir=cache_dir,
         )
 
     # Wrap the model in a feature extraction module if enabled
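
Taken together, the commit threads a single cache_dir argument from build_model_with_cfg down through load_pretrained / load_custom_pretrained into both the torch.hub and Hugging Face Hub download paths. An end-to-end sketch, assuming the model factory forwards the kwarg into build_model_with_cfg (the directory is illustrative):

    import timm

    # Keep downloaded checkpoints in a project-local directory instead of the
    # global torch.hub / huggingface_hub caches.
    model = timm.create_model(
        'resnet50',
        pretrained=True,
        cache_dir='./checkpoints',  # assumption: entrypoint passes this through
    )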