
Commit 63f853d

Merge branch 'dependencyvit' of https://github.com/fffffgggg54/pytorch-image-models into dependencyvit
2 parents: e6f8765 + 94e5558

File tree: 2 files changed (+22, -3 lines)

  timm/models/_builder.py
  train.py

timm/models/_builder.py

Lines changed: 7 additions & 2 deletions
@@ -261,7 +261,7 @@ def _filter_kwargs(kwargs, names):
         kwargs.pop(n, None)


-def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     """ Update the default_cfg and kwargs before passing to model

     Args:
@@ -288,6 +288,11 @@ def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
             if input_size is not None:
                 assert len(input_size) == 3
                 kwargs.setdefault(n, input_size[0])
+        elif n == 'num_classes':
+            default_val = pretrained_cfg.get(n, None)
+            # if default is < 0, don't pass through to model
+            if default_val is not None and default_val >= 0:
+                kwargs.setdefault(n, pretrained_cfg[n])
         else:
             default_val = pretrained_cfg.get(n, None)
             if default_val is not None:
@@ -379,7 +384,7 @@ def build_model_with_cfg(
         # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
         pretrained_cfg = pretrained_cfg.to_dict()

-    _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
+    _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)

     # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
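
Note on the new branch above: a negative num_classes in the pretrained cfg now acts as a sentinel meaning "don't seed this kwarg from the cfg", leaving the head size to the model default or the caller. A minimal standalone sketch of just that behavior (the helper name is illustrative, not timm API):

def _seed_num_classes(pretrained_cfg, kwargs):
    # Mirror of the new elif branch: only seed kwargs['num_classes'] from
    # the pretrained cfg when the value is present and non-negative.
    default_val = pretrained_cfg.get('num_classes', None)
    if default_val is not None and default_val >= 0:
        kwargs.setdefault('num_classes', default_val)
    return kwargs

assert _seed_num_classes({'num_classes': 1000}, {}) == {'num_classes': 1000}
assert _seed_num_classes({'num_classes': -1}, {}) == {}  # sentinel: pass nothing through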

train.py

Lines changed: 15 additions & 1 deletion
@@ -15,6 +15,7 @@
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import json
 import logging
 import os
 import time
@@ -425,7 +426,10 @@ def main():
     factory_kwargs = {}
     if args.pretrained_path:
         # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
-        factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
+        factory_kwargs['pretrained_cfg_overlay'] = dict(
+            file=args.pretrained_path,
+            num_classes=-1,  # force head adaptation
+        )

     model = create_model(
         args.model,
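
This change pairs the local checkpoint path with the num_classes=-1 sentinel, so a file passed via --pretrained-path no longer pins the head to the checkpoint cfg's class count. A usage sketch of the equivalent direct call (the model name and checkpoint path are illustrative; pretrained_cfg_overlay is the same create_model argument used above):

import timm

# Point 'file' at a local checkpoint and use the num_classes=-1 sentinel so
# the classifier head is adapted to the class count requested here rather
# than the one stored in the overlay cfg.
model = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=10,  # e.g. from --num-classes or the dataset
    pretrained_cfg_overlay=dict(file='./checkpoint.pth', num_classes=-1),
)
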
@@ -770,6 +774,7 @@ def main():
         _logger.info(
             f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

+    results = []
     try:
         for epoch in range(start_epoch, num_epochs):
             if hasattr(dataset_train, 'set_epoch'):
@@ -841,11 +846,20 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

+            results.append({
+                'epoch': epoch,
+                'train': train_metrics,
+                'validation': eval_metrics,
+            })
+
     except KeyboardInterrupt:
         pass

+    results = {'all': results}
     if best_metric is not None:
+        results['best'] = results['all'][best_epoch]
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
+    print(f'--result\n{json.dumps(results, indent=4)}')


 def train_one_epoch(
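
With the results collection above, the script now ends by printing a machine-parsable summary after a '--result' marker. Illustratively, with made-up values and the loss/top1/top5 keys timm's train and validation loops typically report, one epoch of output looks roughly like:

--result
{
    "all": [
        {
            "epoch": 0,
            "train": {
                "loss": 6.91
            },
            "validation": {
                "loss": 6.53,
                "top1": 1.24,
                "top5": 5.67
            }
        }
    ],
    "best": {
        "epoch": 0,
        "train": {
            "loss": 6.91
        },
        "validation": {
            "loss": 6.53,
            "top1": 1.24,
            "top5": 5.67
        }
    }
}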
