Commit dfa9298

Add MixNet (https://arxiv.org/abs/1907.09595) with pretrained weights converted from Tensorflow impl
* refactor 'same' convolution and add helper to use MixedConv2d when needed
* improve performance of 'same' padding for cases that can be handled statically
* add support for extra exp, pw, and dw kernel specs with grouping support to decoder/string defs for MixNet
* shuffle some args for a bit more consistency, a little less clutter overall in gen_efficientnet.py
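
The static-padding point in the commit message rests on a simple observation: with stride 1, TF-style 'SAME' padding totals `(k - 1) * d` regardless of input size, so it can be baked into a regular convolution whenever that total is even. With a larger stride the amount depends on the input size. A minimal sketch in plain Python, reusing the arithmetic of `_calc_same_pad` and `_is_static_pad` from this commit; the names `same_pad_total`, `is_static_pad`, and `split_pad` are illustrative stand-ins, not library functions:

```python
import math


def same_pad_total(i, k, s=1, d=1):
    """Total padding along one dimension so that output size == ceil(i / s)."""
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)


def is_static_pad(k, s=1, d=1):
    """True when the padding is symmetric and independent of input size."""
    return s == 1 and (d * (k - 1)) % 2 == 0


def split_pad(total):
    """Split a total into (before, after); an odd remainder goes after."""
    return total // 2, total - total // 2


# Stride 1, 3x3 kernel: the total is always 2, so a fixed padding=1 suffices.
static_cases = {i: split_pad(same_pad_total(i, k=3, s=1)) for i in (7, 224, 225)}

# Stride 2, 3x3 kernel: the split changes with the input size.
even_input = split_pad(same_pad_total(224, k=3, s=2))
odd_input = split_pad(same_pad_total(225, k=3, s=2))
```

For a 224-pixel input at stride 2 the split is asymmetric, which a single symmetric `padding=` value cannot express; that is the case that still needs padding computed in `forward`.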
1 parent 7a92caa commit dfa9298

6 files changed (+403, -157 lines changed)

README.md

Lines changed: 9 additions & 2 deletions
@@ -31,16 +31,17 @@ I've included a few of my favourite models, but this is not an exhaustive collec
 * PNasNet & NASNet-A (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
 * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
     * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
-* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
+* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
     * EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
+    * MixNet (https://arxiv.org/abs/1907.09595) -- validated, compat with TF weights
     * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
     * MobileNet-V1 (https://arxiv.org/abs/1704.04861)
     * MobileNet-V2 (https://arxiv.org/abs/1801.04381)
     * MobileNet-V3 (https://arxiv.org/abs/1905.02244) -- pretrained model good, still no official impl to verify against
     * ChamNet (https://arxiv.org/abs/1812.08934) -- specific arch details hard to find, currently an educated guess
     * FBNet-C (https://arxiv.org/abs/1812.03443) -- TODO A/B variants
     * Single-Path NAS (https://arxiv.org/abs/1904.02877) -- pixel1 variant
-
+
 Use the `--model` arg to specify model for train, validation, inference scripts. Match the all lowercase
 creation fn for the model you'd like.

@@ -118,11 +119,17 @@ I've leveraged the training scripts in this repository to train a few of the mod
 | gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | |
 | gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | |
 | gluon_resnet50_v1d | 79.074 (20.926) | 94.476 (5.524) | 25.58 | bicubic | |
+| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
+| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
 | gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | |
 | gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | |
 | gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | |
+| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
+| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
 | tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
 | tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
+| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
+| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
 | gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | |
 | gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | |

timm/data/loader.py

Lines changed: 2 additions & 1 deletion
@@ -112,7 +112,8 @@ def create_loader(

     if tf_preprocessing and use_prefetcher:
         from timm.data.tf_preprocessing import TfPreprocessTransform
-        transform = TfPreprocessTransform(is_training=is_training, size=img_size)
+        transform = TfPreprocessTransform(
+            is_training=is_training, size=img_size, interpolation=interpolation)
     else:
         if is_training:
             transform = transforms_imagenet_train(

timm/data/tf_preprocessing.py

Lines changed: 21 additions & 14 deletions
@@ -83,7 +83,7 @@ def _at_least_x_are_equal(a, b, x):
   return tf.greater_equal(tf.reduce_sum(match), x)


-def _decode_and_random_crop(image_bytes, image_size):
+def _decode_and_random_crop(image_bytes, image_size, resize_method):
   """Make a random crop of image_size."""
   bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
   image = distorted_bounding_box_crop(
@@ -100,13 +100,12 @@ def _decode_and_random_crop(image_bytes, image_size):
   image = tf.cond(
       bad,
       lambda: _decode_and_center_crop(image_bytes, image_size),
-      lambda: tf.image.resize_bicubic([image],  # pylint: disable=g-long-lambda
-                                      [image_size, image_size])[0])
+      lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0])

   return image


-def _decode_and_center_crop(image_bytes, image_size):
+def _decode_and_center_crop(image_bytes, image_size, resize_method):
   """Crops to center of image with padding then scales image_size."""
   shape = tf.image.extract_jpeg_shape(image_bytes)
   image_height = shape[0]
@@ -122,7 +121,7 @@ def _decode_and_center_crop(image_bytes, image_size):
   crop_window = tf.stack([offset_height, offset_width,
                           padded_center_crop_size, padded_center_crop_size])
   image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
-  image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
+  image = tf.image.resize([image], [image_size, image_size], resize_method)[0]

   return image

@@ -133,37 +132,41 @@ def _flip(image):
   return image


-def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
+def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
   """Preprocesses the given image for evaluation.

   Args:
     image_bytes: `Tensor` representing an image binary of arbitrary size.
     use_bfloat16: `bool` for whether to use bfloat16.
     image_size: image size.
+    interpolation: image interpolation method

   Returns:
     A preprocessed image `Tensor`.
   """
-  image = _decode_and_random_crop(image_bytes, image_size)
+  resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
+  image = _decode_and_random_crop(image_bytes, image_size, resize_method)
   image = _flip(image)
   image = tf.reshape(image, [image_size, image_size, 3])
   image = tf.image.convert_image_dtype(
       image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
   return image


-def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
+def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
   """Preprocesses the given image for evaluation.

   Args:
     image_bytes: `Tensor` representing an image binary of arbitrary size.
     use_bfloat16: `bool` for whether to use bfloat16.
     image_size: image size.
+    interpolation: image interpolation method

   Returns:
     A preprocessed image `Tensor`.
   """
-  image = _decode_and_center_crop(image_bytes, image_size)
+  resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
+  image = _decode_and_center_crop(image_bytes, image_size, resize_method)
   image = tf.reshape(image, [image_size, image_size, 3])
   image = tf.image.convert_image_dtype(
       image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
@@ -173,29 +176,32 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
 def preprocess_image(image_bytes,
                      is_training=False,
                      use_bfloat16=False,
-                     image_size=IMAGE_SIZE):
+                     image_size=IMAGE_SIZE,
+                     interpolation='bicubic'):
   """Preprocesses the given image.

   Args:
     image_bytes: `Tensor` representing an image binary of arbitrary size.
     is_training: `bool` for whether the preprocessing is for training.
     use_bfloat16: `bool` for whether to use bfloat16.
     image_size: image size.
+    interpolation: image interpolation method

   Returns:
     A preprocessed image `Tensor` with value range of [0, 255].
   """
   if is_training:
-    return preprocess_for_train(image_bytes, use_bfloat16, image_size)
+    return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation)
   else:
-    return preprocess_for_eval(image_bytes, use_bfloat16, image_size)
+    return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation)


 class TfPreprocessTransform:

-    def __init__(self, is_training=False, size=224):
+    def __init__(self, is_training=False, size=224, interpolation='bicubic'):
         self.is_training = is_training
         self.size = size[0] if isinstance(size, tuple) else size
+        self.interpolation = interpolation
         self._image_bytes = None
         self.process_image = self._build_tf_graph()
         self.sess = None
@@ -206,7 +212,8 @@ def _build_tf_graph(self):
                 shape=[],
                 dtype=tf.string,
             )
-            img = preprocess_image(self._image_bytes, self.is_training, False, self.size)
+            img = preprocess_image(
+                self._image_bytes, self.is_training, False, self.size, self.interpolation)
         return img

     def __call__(self, image_bytes):
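
The interpolation plumbing in this file comes down to one conditional: `'bicubic'` selects bicubic resizing, and every other string, including a misspelling, silently selects bilinear. The sketch below mirrors that rule with plain strings standing in for `tf.image.ResizeMethod`, and adds a stricter variant for comparison. `resolve_resize_method`, `resolve_resize_method_strict`, and the two constants are hypothetical names for illustration, not part of timm or TensorFlow.

```python
BICUBIC = "bicubic"    # stand-in for tf.image.ResizeMethod.BICUBIC
BILINEAR = "bilinear"  # stand-in for tf.image.ResizeMethod.BILINEAR


def resolve_resize_method(interpolation):
    """Same rule as the diff: anything other than 'bicubic' means bilinear."""
    return BICUBIC if interpolation == "bicubic" else BILINEAR


def resolve_resize_method_strict(interpolation):
    """Alternative (an assumption, not in the commit): reject unknown names."""
    methods = {"bicubic": BICUBIC, "bilinear": BILINEAR}
    try:
        return methods[interpolation]
    except KeyError:
        raise ValueError(f"unsupported interpolation: {interpolation!r}") from None
```

One detail worth checking when reading the hunks above: the unchanged fallback lambda in `_decode_and_random_crop` still calls `_decode_and_center_crop(image_bytes, image_size)` with two arguments, while the new signature requires `resize_method`, so that branch would presumably raise a `TypeError` when invoked.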

timm/models/conv2d_helpers.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
+    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
+
+
+def _get_padding(kernel_size, stride=1, dilation=1, **_):
+    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+    return padding
+
+
+def _calc_same_pad(i, k, s, d):
+    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
+
+
+def _split_channels(num_chan, num_groups):
+    split = [num_chan // num_groups for _ in range(num_groups)]
+    split[0] += num_chan - sum(split)
+    return split
+
+
+class Conv2dSame(nn.Conv2d):
+    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        super(Conv2dSame, self).__init__(
+            in_channels, out_channels, kernel_size, stride, 0, dilation,
+            groups, bias)
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.weight.size()[-2:]
+        pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0])
+        pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1])
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
+        return F.conv2d(x, self.weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
+    padding = kwargs.pop('padding', '')
+    kwargs.setdefault('bias', False)
+    if isinstance(padding, str):
+        # for any string padding, the padding will be calculated for you, one of three ways
+        padding = padding.lower()
+        if padding == 'same':
+            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+            if _is_static_pad(kernel_size, **kwargs):
+                # static case, no extra overhead
+                padding = _get_padding(kernel_size, **kwargs)
+                return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
+            else:
+                # dynamic padding
+                return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+        elif padding == 'valid':
+            # 'VALID' padding, same as padding=0
+            return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
+        else:
+            # Default to PyTorch style 'same'-ish symmetric padding
+            padding = _get_padding(kernel_size, **kwargs)
+            return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
+    else:
+        # padding was specified as a number or pair
+        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
+
+
+class MixedConv2d(nn.Module):
+    """ Mixed Grouped Convolution
+    Based on MDConv and GroupedConv in MixNet impl:
+    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding='', dilated=False, depthwise=False, **kwargs):
+        super(MixedConv2d, self).__init__()
+
+        kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
+        num_groups = len(kernel_size)
+        in_splits = _split_channels(in_channels, num_groups)
+        out_splits = _split_channels(out_channels, num_groups)
+        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
+            d = 1
+            # FIXME make compat with non-square kernel/dilations/strides
+            if stride == 1 and dilated:
+                d, k = (k - 1) // 2, 3
+            conv_groups = out_ch if depthwise else 1
+            # use add_module to keep key space clean
+            self.add_module(
+                str(idx),
+                conv2d_pad(
+                    in_ch, out_ch, k, stride=stride,
+                    padding=padding, dilation=d, groups=conv_groups, **kwargs)
+            )
+        self.splits = in_splits
+
+    def forward(self, x):
+        x_split = torch.split(x, self.splits, 1)
+        x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
+        x = torch.cat(x_out, 1)
+        return x
+
+
+# helper method
+def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
+    assert 'groups' not in kwargs  # only use 'depthwise' bool arg
+    if isinstance(kernel_size, list):
+        # We're going to use only lists for defining the MixedConv2d kernel groups,
+        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
+        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
+    else:
+        depthwise = kwargs.pop('depthwise', False)
+        groups = out_chs if depthwise else 1
+        return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+
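
`MixedConv2d` does two pieces of bookkeeping before it builds its sub-convolutions. It splits the channels as evenly as possible across the kernel groups, and, when `dilated=True` with stride 1, it swaps each kernel of size `k` for a 3x3 kernel with dilation `(k - 1) // 2`. A small torch-free sketch of that arithmetic follows. `split_channels`, `dilated_equivalent`, and `receptive_field` are illustrative names, and the receptive-field check is an added sanity test rather than something in the commit.

```python
def split_channels(num_chan, num_groups):
    """Even split; the first group absorbs any remainder."""
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


def dilated_equivalent(k):
    """Return (kernel_size, dilation) used in place of a k x k kernel."""
    return 3, (k - 1) // 2


def receptive_field(k, d):
    """Extent covered along one axis by a kernel of size k with dilation d."""
    return d * (k - 1) + 1


# 40 channels shared between three kernel groups, e.g. kernel_size=[3, 5, 7].
splits = split_channels(40, 3)

# Dilated substitutes for odd kernel sizes.
substitutes = {k: dilated_equivalent(k) for k in (3, 5, 7)}
```

For odd kernel sizes the dilated 3x3 substitute spans the same extent as the original kernel while using only nine weights per channel. An even `k` would not round-trip through `(k - 1) // 2`, so the substitution only preserves the extent for odd sizes.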

timm/models/conv2d_same.py

Lines changed: 0 additions & 39 deletions
This file was deleted.
