Skip to content

Commit 60b170b

Browse files
committed
Add --pretrained-path arg to train script to allow passing local checkpoint as pretrained. Add missing/unexpected keys log.
1 parent 17a47c0 commit 60b170b

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

timm/models/_builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,15 @@ def load_pretrained(
234234
classifier_bias = state_dict[classifier_name + '.bias']
235235
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
236236

237-
model.load_state_dict(state_dict, strict=strict)
237+
load_result = model.load_state_dict(state_dict, strict=strict)
238+
if load_result.missing_keys:
239+
_logger.info(
240+
f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
241+
f' This is expected if model is being adapted.')
242+
if load_result.unexpected_keys:
243+
_logger.warning(
244+
f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
245+
f' This may be expected if model is being adapted.')
238246

239247

240248
def pretrained_cfg_for_features(pretrained_cfg):

train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@
103103
help='Name of model to train (default: "resnet50")')
104104
group.add_argument('--pretrained', action='store_true', default=False,
105105
help='Start with pretrained version of specified network (if avail)')
106+
group.add_argument('--pretrained-path', default=None, type=str,
107+
help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
106108
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
107-
help='Initialize model from this checkpoint (default: none)')
109+
help='Load this checkpoint into model after initialization (default: none)')
108110
group.add_argument('--resume', default='', type=str, metavar='PATH',
109111
help='Resume full model and optimizer state from checkpoint (default: none)')
110112
group.add_argument('--no-resume-opt', action='store_true', default=False,
@@ -420,6 +422,11 @@ def main():
420422
elif args.input_size is not None:
421423
in_chans = args.input_size[0]
422424

425+
factory_kwargs = {}
426+
if args.pretrained_path:
427+
# merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
428+
factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
429+
423430
model = create_model(
424431
args.model,
425432
pretrained=args.pretrained,
@@ -433,6 +440,7 @@ def main():
433440
bn_eps=args.bn_eps,
434441
scriptable=args.torchscript,
435442
checkpoint_path=args.initial_checkpoint,
443+
**factory_kwargs,
436444
**args.model_kwargs,
437445
)
438446
if args.head_init_scale is not None:

0 commit comments

Comments
 (0)