Commit b1b6e7c

Fix a few more issues related to #216 w/ TResNet (space2depth) and FP16 weights in wide resnets. Also don't completely dump pretrained weights in in_chans != 1 or 3 cases.
1 parent 512b2dd commit b1b6e7c
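
In practice, the second half of this commit means create_model can now adapt pretrained first-conv weights to an arbitrary number of input channels instead of asserting on anything other than 1 or 3. A minimal usage sketch, assuming timm is installed and importable, with resnet50 chosen purely as an example:

    import timm

    # in_chans=1: first-conv weights are summed across RGB (now fp16- and space2depth-safe)
    gray_model = timm.create_model('resnet50', pretrained=True, in_chans=1)

    # in_chans=8: first-conv weights are tiled across channels and rescaled
    # rather than discarded (the new behavior from this commit)
    multi_model = timm.create_model('resnet50', pretrained=True, in_chans=8)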

2 files changed: 40 additions & 5 deletions

tests/test_models.py

Lines changed: 3 additions & 2 deletions
@@ -115,8 +115,9 @@ def test_model_default_cfgs(model_name, batch_size):
 @pytest.mark.parametrize('model_name', list_models(pretrained=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_load_pretrained(model_name, batch_size):
-    """Run a single forward pass with each model"""
-    create_model(model_name, pretrained=True)
+    """Check that pretrained weights load, and verify support for in_chans != 3 while doing so."""
+    in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
+    create_model(model_name, pretrained=True, in_chans=in_chans)
 
 
 EXCLUDE_JIT_FILTERS = [
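
Beyond confirming that the weights load, a model created with in_chans=1 should accept single-channel input end to end. A quick sanity check along these lines (a sketch, not part of the test suite; resnet18 and the 224x224 input size are arbitrary choices matching its default config):

    import torch
    from timm import create_model

    model = create_model('resnet18', pretrained=True, in_chans=1).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 1, 224, 224))  # a batch of single-channel images
    assert out.shape == (1, 1000)  # default ImageNet-1k classifier head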

timm/models/helpers.py

Lines changed: 37 additions & 3 deletions
@@ -1,5 +1,10 @@
+""" Model creation / weight loading / state_dict helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 import logging
 import os
+import math
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Callable
@@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
 
     if in_chans == 1:
         conv1_name = cfg['first_conv']
-        _logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
+        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
         conv1_weight = state_dict[conv1_name + '.weight']
-        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
+        # Some weights are in torch.half, ensure it's float for sum on CPU
+        conv1_type = conv1_weight.dtype
+        conv1_weight = conv1_weight.float()
+        O, I, J, K = conv1_weight.shape
+        if I > 3:
+            assert conv1_weight.shape[1] % 3 == 0
+            # For models with space2depth stems
+            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
+            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
+        else:
+            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
+        conv1_weight = conv1_weight.to(conv1_type)
+        state_dict[conv1_name + '.weight'] = conv1_weight
     elif in_chans != 3:
-        assert False, "Invalid in_chans for pretrained weights"
+        conv1_name = cfg['first_conv']
+        conv1_weight = state_dict[conv1_name + '.weight']
+        conv1_type = conv1_weight.dtype
+        conv1_weight = conv1_weight.float()
+        O, I, J, K = conv1_weight.shape
+        if I != 3:
+            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
+            del state_dict[conv1_name + '.weight']
+            strict = False
+        else:
+            # NOTE this strategy should be better than random init, but there could be other combinations of
+            # the original RGB input layer weights that'd work better for specific cases.
+            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
+            repeat = int(math.ceil(in_chans / 3))
+            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+            conv1_weight *= (3 / float(in_chans))
+            conv1_weight = conv1_weight.to(conv1_type)
+            state_dict[conv1_name + '.weight'] = conv1_weight
 
     classifier_name = cfg['classifier']
     if num_classes == 1000 and cfg['num_classes'] == 1001:
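
The conversion arithmetic above can be exercised in isolation. The sketch below mirrors the new load_pretrained logic on a bare weight tensor; adapt_first_conv is a hypothetical helper for illustration only, not part of the timm API, and unlike the real code it assumes an RGB (I == 3) weight on the repeat path rather than falling back to deleting the layer:

    import math
    import torch

    def adapt_first_conv(weight: torch.Tensor, in_chans: int) -> torch.Tensor:
        # weight has shape (O, I, J, K): out channels, in channels, kernel H, kernel W
        conv_type = weight.dtype
        weight = weight.float()  # fp16 math on CPU is problematic, hence the round-trip cast
        O, I, J, K = weight.shape
        if in_chans == 1:
            if I > 3:
                # space2depth stems (e.g. TResNet) stack S*S RGB pixels, so I = 3*S*S;
                # sum each RGB triplet so the stem consumes stacked grayscale pixels
                assert I % 3 == 0
                weight = weight.reshape(O, I // 3, 3, J, K).sum(dim=2)
            else:
                weight = weight.sum(dim=1, keepdim=True)
        else:
            # tile the RGB weights across the new channel count, then rescale so the
            # expected magnitude of the conv output matches the 3-channel original
            repeat = int(math.ceil(in_chans / 3))
            weight = weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            weight = weight * (3 / float(in_chans))
        return weight.to(conv_type)

    w = torch.randn(64, 3, 7, 7).half()      # typical ResNet stem conv, fp16 checkpoint
    print(adapt_first_conv(w, 1).shape)      # torch.Size([64, 1, 7, 7])
    print(adapt_first_conv(w, 5).shape)      # torch.Size([64, 5, 7, 7])

    s2d = torch.randn(64, 48, 3, 3)          # TResNet-style stem: 3 * 4 * 4 = 48 channels
    print(adapt_first_conv(s2d, 1).shape)    # torch.Size([64, 16, 3, 3])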
