103
103
help = 'Name of model to train (default: "resnet50")' )
104
104
group .add_argument ('--pretrained' , action = 'store_true' , default = False ,
105
105
help = 'Start with pretrained version of specified network (if avail)' )
106
+ group .add_argument ('--pretrained-path' , default = None , type = str ,
107
+ help = 'Load this checkpoint as if they were the pretrained weights (with adaptation).' )
106
108
group .add_argument ('--initial-checkpoint' , default = '' , type = str , metavar = 'PATH' ,
107
- help = 'Initialize model from this checkpoint (default: none)' )
109
+ help = 'Load this checkpoint into model after initialization (default: none)' )
108
110
group .add_argument ('--resume' , default = '' , type = str , metavar = 'PATH' ,
109
111
help = 'Resume full model and optimizer state from checkpoint (default: none)' )
110
112
group .add_argument ('--no-resume-opt' , action = 'store_true' , default = False ,
@@ -420,6 +422,11 @@ def main():
420
422
elif args .input_size is not None :
421
423
in_chans = args .input_size [0 ]
422
424
425
+ factory_kwargs = {}
426
+ if args .pretrained_path :
427
+ # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
428
+ factory_kwargs ['pretrained_cfg_overlay' ] = dict (file = args .pretrained_path )
429
+
423
430
model = create_model (
424
431
args .model ,
425
432
pretrained = args .pretrained ,
@@ -433,6 +440,7 @@ def main():
433
440
bn_eps = args .bn_eps ,
434
441
scriptable = args .torchscript ,
435
442
checkpoint_path = args .initial_checkpoint ,
443
+ ** factory_kwargs ,
436
444
** args .model_kwargs ,
437
445
)
438
446
if args .head_init_scale is not None :
0 commit comments