Skip to content

Commit 90a01f4

Browse files
committed
hrnet features_only pretrained weight loading issue. Fix #232.
1 parent 110a7c4 commit 90a01f4

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

tests/test_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def test_model_load_pretrained(model_name, batch_size):
120120
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
121121
create_model(model_name, pretrained=True, in_chans=in_chans)
122122

123+
@pytest.mark.timeout(120)
124+
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
125+
@pytest.mark.parametrize('batch_size', [1])
126+
def test_model_features_pretrained(model_name, batch_size):
127+
"""Create that pretrained weights load when features_only==True."""
128+
create_model(model_name, pretrained=True, features_only=True)
123129

124130
EXCLUDE_JIT_FILTERS = [
125131
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable

timm/models/hrnet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,12 +773,14 @@ def forward(self, x) -> List[torch.tensor]:
773773

774774
def _create_hrnet(variant, pretrained, **model_kwargs):
775775
model_cls = HighResolutionNet
776+
strict = True
776777
if model_kwargs.pop('features_only', False):
777778
model_cls = HighResolutionNetFeatures
779+
strict = False
778780

779781
return build_model_with_cfg(
780782
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
781-
model_cfg=cfg_cls[variant], **model_kwargs)
783+
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
782784

783785

784786
@register_model

0 commit comments

Comments
 (0)