|
| 1 | +""" Model creation / weight loading / state_dict helpers |
| 2 | +
|
| 3 | +Hacked together by / Copyright 2020 Ross Wightman |
| 4 | +""" |
1 | 5 | import logging
|
2 | 6 | import os
|
| 7 | +import math |
3 | 8 | from collections import OrderedDict
|
4 | 9 | from copy import deepcopy
|
5 | 10 | from typing import Callable
|
@@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
86 | 91 |
|
87 | 92 | if in_chans == 1:
|
88 | 93 | conv1_name = cfg['first_conv']
|
89 |
| - _logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) |
| 94 | + _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) |
90 | 95 | conv1_weight = state_dict[conv1_name + '.weight']
|
91 |
| - state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) |
| 96 | + # Some weights are in torch.half, ensure it's float for sum on CPU |
| 97 | + conv1_type = conv1_weight.dtype |
| 98 | + conv1_weight = conv1_weight.float() |
| 99 | + O, I, J, K = conv1_weight.shape |
| 100 | + if I > 3: |
| 101 | + assert conv1_weight.shape[1] % 3 == 0 |
| 102 | + # For models with space2depth stems |
| 103 | + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) |
| 104 | + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) |
| 105 | + else: |
| 106 | + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) |
| 107 | + conv1_weight = conv1_weight.to(conv1_type) |
| 108 | + state_dict[conv1_name + '.weight'] = conv1_weight |
92 | 109 | elif in_chans != 3:
|
93 |
| - assert False, "Invalid in_chans for pretrained weights" |
| 110 | + conv1_name = cfg['first_conv'] |
| 111 | + conv1_weight = state_dict[conv1_name + '.weight'] |
| 112 | + conv1_type = conv1_weight.dtype |
| 113 | + conv1_weight = conv1_weight.float() |
| 114 | + O, I, J, K = conv1_weight.shape |
| 115 | + if I != 3: |
| 116 | + _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) |
| 117 | + del state_dict[conv1_name + '.weight'] |
| 118 | + strict = False |
| 119 | + else: |
| 120 | + # NOTE this strategy should be better than random init, but there could be other combinations of |
| 121 | + # the original RGB input layer weights that'd work better for specific cases. |
| 122 | + _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name) |
| 123 | + repeat = int(math.ceil(in_chans / 3)) |
| 124 | + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] |
| 125 | + conv1_weight *= (3 / float(in_chans)) |
| 126 | + conv1_weight = conv1_weight.to(conv1_type) |
| 127 | + state_dict[conv1_name + '.weight'] = conv1_weight |
94 | 128 |
|
95 | 129 | classifier_name = cfg['classifier']
|
96 | 130 | if num_classes == 1000 and cfg['num_classes'] == 1001:
|
|
0 commit comments