
Commit a99ec4e (parent 13746a3)

A bunch more layer reorg, splitting many layers into their own files. Improve TorchScript compatibility.

16 files changed: 479 additions, 396 deletions

timm/models/efficientnet.py
Lines changed: 2 additions & 1 deletion

@@ -27,7 +27,8 @@
 from .feature_hooks import FeatureHooks
 from .registry import register_model
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, select_conv2d
+from .layers import SelectAdaptivePool2d
+from timm.models.layers import select_conv2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD


timm/models/efficientnet_blocks.py
Lines changed: 19 additions & 16 deletions

@@ -1,11 +1,8 @@
-
-from functools import partial
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+from torch.nn import functional as F
 from .layers.activations import sigmoid
-from .layers.conv2d_layers import *
+from .layers import select_conv2d


 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@@ -72,7 +69,7 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
     return make_divisible(channels, divisor, channel_min)


-def drop_connect(inputs, training=False, drop_connect_rate=0.):
+def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
     """Apply drop connect."""
     if not training:
         return inputs
@@ -160,7 +157,7 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         norm_kwargs = norm_kwargs or {}
-        self.has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_connect_rate = drop_connect_rate
@@ -171,9 +168,11 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if self.has_se:
+        if has_se:
             se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
             self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
+        else:
+            self.se = None

         self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
@@ -193,7 +192,7 @@ def forward(self, x):
         x = self.bn1(x)
         x = self.act1(x)

-        if self.has_se:
+        if self.se is not None:
             x = self.se(x)

         x = self.conv_pw(x)
@@ -219,7 +218,7 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        self.has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_connect_rate = drop_connect_rate

@@ -236,9 +235,11 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         self.act2 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if self.has_se:
+        if has_se:
             se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
             self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+        else:
+            self.se = None

         # Point-wise linear projection
         self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -269,7 +270,7 @@ def forward(self, x):
         x = self.act2(x)

         # Squeeze-and-excitation
-        if self.has_se:
+        if self.se is not None:
             x = self.se(x)

         # Point-wise linear projection
@@ -323,7 +324,7 @@ def forward(self, x):
         x = self.act2(x)

         # Squeeze-and-excitation
-        if self.has_se:
+        if self.se is not None:
             x = self.se(x)

         # Point-wise linear projection
@@ -350,7 +351,7 @@ def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs
             mid_chs = make_divisible(fake_in_chs * exp_ratio)
         else:
             mid_chs = make_divisible(in_chs * exp_ratio)
-        self.has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_connect_rate = drop_connect_rate

@@ -360,9 +361,11 @@ def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if self.has_se:
+        if has_se:
             se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
             self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+        else:
+            self.se = None

         # Point-wise linear projection
         self.conv_pwl = select_conv2d(
@@ -389,7 +392,7 @@ def forward(self, x):
         x = self.act1(x)

         # Squeeze-and-excitation
-        if self.has_se:
+        if self.se is not None:
             x = self.se(x)

         # Point-wise linear projection

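Note on the `self.se = None` pattern above: torch.jit.script compiles every branch of forward(), so guarding `x = self.se(x)` with a bare `self.has_se` flag breaks scripting whenever the block was built without SE, because the `se` attribute never gets assigned at all. Assigning `self.se = None` in the else branch and testing `if self.se is not None:` keeps the attribute present in every configuration. A minimal sketch of the idea, using a toy module rather than the timm blocks themselves:

import torch
import torch.nn as nn


class ToySEBlock(nn.Module):
    """Illustrative only: mirrors the optional-submodule pattern used in the blocks above."""

    def __init__(self, chs: int, use_se: bool = True):
        super(ToySEBlock, self).__init__()
        self.conv = nn.Conv2d(chs, chs, 3, padding=1)
        # Always assign the attribute; None when disabled, so TorchScript can type it.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(chs, chs, 1), nn.Sigmoid()) if use_se else None

    def forward(self, x):
        x = self.conv(x)
        if self.se is not None:  # scriptable; `if self.has_se:` over a possibly-missing attribute is not
            x = x * self.se(x)
        return x


scripted = torch.jit.script(ToySEBlock(8, use_se=False))  # scripts cleanly with or without SE
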
timm/models/efficientnet_builder.py
Lines changed: 2 additions & 1 deletion

@@ -5,7 +5,8 @@
 from copy import deepcopy

 import torch.nn as nn
-from .layers.activations import sigmoid, HardSwish, Swish
+from .layers import CondConv2d, get_condconv_initializer
+from .layers.activations import HardSwish, Swish
 from .efficientnet_blocks import *


timm/models/layers/__init__.py
Lines changed: 6 additions & 2 deletions

@@ -1,8 +1,12 @@
-from .conv2d_layers import select_conv2d, MixedConv2d, CondConv2d, ConvBnAct, SelectiveKernelConv
+from .conv_bn_act import ConvBnAct
+from .mixed_conv2d import MixedConv2d
+from .cond_conv2d import CondConv2d, get_condconv_initializer
+from .select_conv2d import select_conv2d
+from .selective_kernel import SelectiveKernelConv
 from .eca import EcaModule, CecaModule
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .nn_ops import DropBlock2d, DropPath
+from .drop import DropBlock2d, DropPath
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

timm/models/layers/activations.py
Lines changed: 22 additions & 13 deletions

@@ -1,9 +1,18 @@
+""" Activations
+
+A collection of activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+Hacked together by Ross Wightman
+"""
+
+
 import torch
 from torch import nn as nn
 from torch.nn import functional as F


-_USE_MEM_EFFICIENT_ISH = True
+_USE_MEM_EFFICIENT_ISH = False
 if _USE_MEM_EFFICIENT_ISH:
     # This version reduces memory overhead of Swish during training by
     # recomputing torch.sigmoid(x) in backward instead of saving it.
@@ -66,20 +75,20 @@ def mish(x, _inplace=False):
         return MishJitAutoFn.apply(x)

 else:
-    def swish(x, inplace=False):
+    def swish(x, inplace: bool = False):
         """Swish - Described in: https://arxiv.org/abs/1710.05941
         """
         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())


-    def mish(x, _inplace=False):
+    def mish(x, _inplace: bool = False):
         """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
         """
         return x.mul(F.softplus(x).tanh())


 class Swish(nn.Module):
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super(Swish, self).__init__()
         self.inplace = inplace

@@ -88,65 +97,65 @@ def forward(self, x):


 class Mish(nn.Module):
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super(Mish, self).__init__()
         self.inplace = inplace

     def forward(self, x):
         return mish(x, self.inplace)


-def sigmoid(x, inplace=False):
+def sigmoid(x, inplace: bool = False):
     return x.sigmoid_() if inplace else x.sigmoid()


 # PyTorch has this, but not with a consistent inplace argmument interface
 class Sigmoid(nn.Module):
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super(Sigmoid, self).__init__()
         self.inplace = inplace

     def forward(self, x):
         return x.sigmoid_() if self.inplace else x.sigmoid()


-def tanh(x, inplace=False):
+def tanh(x, inplace: bool = False):
     return x.tanh_() if inplace else x.tanh()


 # PyTorch has this, but not with a consistent inplace argmument interface
 class Tanh(nn.Module):
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super(Tanh, self).__init__()
         self.inplace = inplace

     def forward(self, x):
         return x.tanh_() if self.inplace else x.tanh()


-def hard_swish(x, inplace=False):
+def hard_swish(x, inplace: bool = False):
     inner = F.relu6(x + 3.).div_(6.)
     return x.mul_(inner) if inplace else x.mul(inner)


 class HardSwish(nn.Module):
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super(HardSwish, self).__init__()
         self.inplace = inplace

     def forward(self, x):
         return hard_swish(x, self.inplace)


-def hard_sigmoid(x, inplace=False):
+def hard_sigmoid(x, inplace: bool = False):
     if inplace:
         return x.add_(3.).clamp_(0., 6.).div_(6.)
     else:
         return F.relu6(x + 3.) / 6.


 class HardSigmoid(nn.Module):
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super(HardSigmoid, self).__init__()
         self.inplace = inplace

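The added `inplace: bool = False` annotations are also part of the TorchScript work: torch.jit.script assumes unannotated function arguments are Tensors, so a boolean flag needs an explicit annotation before a function that branches on it can be scripted. A small illustration with a hypothetical function (not one from this file):

import torch
from torch.nn import functional as F


@torch.jit.script
def toy_hard_swish(x, inplace: bool = False):
    # Without the `bool` annotation, TorchScript would type `inplace` as Tensor
    # and reject this boolean branch at compile time.
    inner = F.relu6(x + 3.) / 6.
    return x.mul_(inner) if inplace else x.mul(inner)


y = toy_hard_swish(torch.randn(2, 8))
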
timm/models/layers/cond_conv2d.py
Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
+""" Conditional Convolution
+
+Hacked together by Ross Wightman
+"""
+
+import math
+from functools import partial
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .conv2d_same import get_padding_value, conv2d_same
+from .conv_helpers import tup_pair
+
+
+def get_condconv_initializer(initializer, num_experts, expert_shape):
+    def condconv_initializer(weight):
+        """CondConv initializer function."""
+        num_params = np.prod(expert_shape)
+        if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
+                weight.shape[1] != num_params):
+            raise (ValueError(
+                'CondConv variables must have shape [num_experts, num_params]'))
+        for i in range(num_experts):
+            initializer(weight[i].view(expert_shape))
+    return condconv_initializer
+
+
+class CondConv2d(nn.Module):
+    """ Conditional Convolution
+    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
+
+    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
+    https://github.com/pytorch/pytorch/issues/17983
+    """
+    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
+        super(CondConv2d, self).__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = tup_pair(kernel_size)
+        self.stride = tup_pair(stride)
+        padding_val, is_padding_dynamic = get_padding_value(
+            padding, kernel_size, stride=stride, dilation=dilation)
+        self.dynamic_padding = is_padding_dynamic  # if in forward to work with torchscript
+        self.padding = tup_pair(padding_val)
+        self.dilation = tup_pair(dilation)
+        self.groups = groups
+        self.num_experts = num_experts
+
+        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
+        weight_num_param = 1
+        for wd in self.weight_shape:
+            weight_num_param *= wd
+        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
+
+        if bias:
+            self.bias_shape = (self.out_channels,)
+            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
+        else:
+            self.register_parameter('bias', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init_weight = get_condconv_initializer(
+            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
+        init_weight(self.weight)
+        if self.bias is not None:
+            fan_in = np.prod(self.weight_shape[1:])
+            bound = 1 / math.sqrt(fan_in)
+            init_bias = get_condconv_initializer(
+                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
+            init_bias(self.bias)
+
+    def forward(self, x, routing_weights):
+        B, C, H, W = x.shape
+        weight = torch.matmul(routing_weights, self.weight)
+        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
+        weight = weight.view(new_weight_shape)
+        bias = None
+        if self.bias is not None:
+            bias = torch.matmul(routing_weights, self.bias)
+            bias = bias.view(B * self.out_channels)
+        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
+        x = x.view(1, B * C, H, W)
+        if self.dynamic_padding:
+            out = conv2d_same(
+                x, weight, bias, stride=self.stride, padding=self.padding,
+                dilation=self.dilation, groups=self.groups * B)
+        else:
+            out = F.conv2d(
+                x, weight, bias, stride=self.stride, padding=self.padding,
+                dilation=self.dilation, groups=self.groups * B)
+        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
+
+        # Literal port (from TF definition)
+        # x = torch.split(x, 1, 0)
+        # weight = torch.split(weight, 1, 0)
+        # if self.bias is not None:
+        #     bias = torch.matmul(routing_weights, self.bias)
+        #     bias = torch.split(bias, 1, 0)
+        # else:
+        #     bias = [None] * B
+        # out = []
+        # for xi, wi, bi in zip(x, weight, bias):
+        #     wi = wi.view(*self.weight_shape)
+        #     if bi is not None:
+        #         bi = bi.view(*self.bias_shape)
+        #     out.append(self.conv_fn(
+        #         xi, wi, bi, stride=self.stride, padding=self.padding,
+        #         dilation=self.dilation, groups=self.groups))
+        # out = torch.cat(out, 0)
+        return out

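For context, CondConv2d.forward takes a second argument, routing_weights of shape [batch, num_experts]; the matmul mixes the per-expert kernels into one kernel per sample, and the grouped-convolution trick then runs all samples in a single conv call. A usage sketch under that assumption; the pooling-plus-linear routing head here is a made-up stand-in, not how the efficientnet blocks wire it up:

import torch
import torch.nn as nn

from timm.models.layers import CondConv2d  # exported by the updated layers/__init__.py above

num_experts, in_chs, out_chs = 4, 16, 32
conv = CondConv2d(in_chs, out_chs, kernel_size=3, padding='same', num_experts=num_experts)
routing_fc = nn.Linear(in_chs, num_experts)  # hypothetical per-sample routing head

x = torch.randn(8, in_chs, 14, 14)
pooled = x.mean(dim=(2, 3))                           # [8, in_chs]
routing_weights = torch.sigmoid(routing_fc(pooled))   # [8, num_experts], one weight per expert
y = conv(x, routing_weights)                          # [8, out_chs, 14, 14]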