Commit f1d5f8a

Update comments for Selective Kernel and DropBlock/Path impl, add skresnet34 weights
1 parent 569419b commit f1d5f8a

File tree: 5 files changed (+70, -14 lines)


timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -12,6 +12,6 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .drop import DropBlock2d, DropPath
+from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

timm/models/layers/drop.py

Lines changed: 31 additions & 10 deletions

@@ -2,6 +2,16 @@
 
 PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
 
+Papers:
+DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+Code:
+DropBlock impl inspired by two Tensorflow impl that I liked:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
 Hacked together by Ross Wightman
 """
 import torch
@@ -11,9 +21,15 @@
 import math
 
 
-def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
+def drop_block_2d(x, drop_prob=0.1, training=False, block_size=7, gamma_scale=1.0, drop_with_noise=False):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
+    runs with success, but needs further validation and possibly optimization for lower runtime impact.
+
     """
+    if drop_prob == 0. or not training:
+        return x
     _, _, height, width = x.shape
     total_size = width * height
     clipped_block_size = min(block_size, min(width, height))
@@ -60,14 +76,21 @@ def __init__(self,
         self.with_noise = with_noise
 
     def forward(self, x):
-        if not self.training or not self.drop_prob:
-            return x
-        return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise)
+        return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
 
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
 
-def drop_path(x, drop_prob=0.):
-    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
     """
+    if drop_prob == 0. or not training:
+        return x
     keep_prob = 1 - drop_prob
     random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
     random_tensor.floor_()  # binarize
@@ -76,13 +99,11 @@ def drop_path(x, drop_prob=0.):
 
 
 class DropPath(nn.ModuleDict):
-    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
     def __init__(self, drop_prob=None):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
 
     def forward(self, x):
-        if not self.training or not self.drop_prob:
-            return x
-        return drop_path(x, self.drop_prob)
+        return drop_path(x, self.drop_prob, self.training)
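
The change above threads the training flag through the functional forms rather than checking it in the modules. A minimal usage sketch, not part of the commit; the ResidualUnit wrapper, tensor shapes, and rates are made up for illustration:

import torch
import torch.nn as nn
from timm.models.layers import DropPath, drop_block_2d

x = torch.randn(8, 64, 56, 56)  # NCHW feature map

# Functional DropBlock: with the new signature it returns x unchanged when
# drop_prob is 0 or training is False, so training=True must be passed here.
y = drop_block_2d(x, drop_prob=0.1, training=True, block_size=7)

# DropPath now forwards self.training into drop_path() instead of
# short-circuiting inside the module.
class ResidualUnit(nn.Module):  # hypothetical block, for illustration only
    def __init__(self, channels, drop_path_rate=0.05):
        super(ResidualUnit, self).__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.drop_path = DropPath(drop_path_rate)

    def forward(self, x):
        # residual branch is dropped per sample during training only
        return self.act(x + self.drop_path(self.bn(self.conv(x))))

block = ResidualUnit(64).train()
out = block(x)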

timm/models/layers/selective_kernel.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ def _kernel_valid(k):
 class SelectiveKernelAttn(nn.Module):
     def __init__(self, channels, num_paths=2, attn_channels=32,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        """ Selective Kernel Attention Module
+
+        Selective Kernel attention mechanism factored out into its own module.
+
+        """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
         self.pool = nn.AdaptiveAvgPool2d(1)
@@ -48,8 +53,33 @@ class SelectiveKernelConv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
                  attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
                  drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        """ Selective Kernel Convolution Module
+
+        As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
+
+        Largest change is the input split, which divides the input channels across each convolution path, this can
+        be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
+        the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
+        a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
+
+        Args:
+            in_channels (int): module input (feature) channel count
+            out_channels (int): module output (feature) channel count
+            kernel_size (int, list): kernel size for each convolution branch
+            stride (int): stride for convolutions
+            dilation (int): dilation for module as a whole, impacts dilation of each branch
+            groups (int): number of groups for each branch
+            attn_reduction (int, float): reduction factor for attention features
+            min_attn_channels (int): minimum attention feature channels
+            keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
+            split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
+                can be viewed as grouping by path, output expands to module out_channels count
+            drop_block (nn.Module): drop block module
+            act_layer (nn.Module): activation layer to use
+            norm_layer (nn.Module): batchnorm/norm layer to use
+        """
         super(SelectiveKernelConv, self).__init__()
-        kernel_size = kernel_size or [3, 5]
+        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
         _kernel_valid(kernel_size)
         if not isinstance(kernel_size, list):
             kernel_size = [kernel_size] * 2
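
To make the new Args docstring concrete, a small sketch of how SelectiveKernelConv might be instantiated. This is not part of the commit; the channel counts and input size are arbitrary, and the import path simply follows the file location in this tree:

import torch
from timm.models.layers.selective_kernel import SelectiveKernelConv

# Defaults give two branches (3x3 and 5x5, with the 5x5 realized as a dilated
# 3x3 when keep_3x3=True); split_input divides the 128 input channels across
# the branches while the output still expands to out_channels.
sk = SelectiveKernelConv(
    in_channels=128,
    out_channels=128,
    stride=1,
    attn_reduction=16,
    min_attn_channels=32,
    split_input=True,
)

x = torch.randn(2, 128, 28, 28)
y = sk(x)
print(y.shape)  # expected: torch.Size([2, 128, 28, 28])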

timm/models/resnet.py

Lines changed: 1 addition & 1 deletion
@@ -382,7 +382,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
-        dp = DropPath(drop_path_rate) if drop_block_rate else None
+        dp = DropPath(drop_path_rate) if drop_path_rate else None
         db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
         db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
         channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
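
This one-word fix matters because drop_path_rate was previously ignored unless drop_block_rate was also non-zero. A minimal sketch of enabling stochastic depth on its own, not part of the commit and assuming create_model forwards these constructor kwargs:

import timm

# With the corrected condition, stochastic depth can be enabled independently of DropBlock.
model = timm.create_model('resnet50', drop_path_rate=0.05)
model.train()

# DropBlock is still gated on drop_block_rate for stages 3 and 4, as in the context lines above.
model_db = timm.create_model('resnet50', drop_block_rate=0.1)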

timm/models/sknet.py

Lines changed: 6 additions & 1 deletion
@@ -2,6 +2,10 @@
 
 Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
 
+This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
+and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
+to the original paper with some modifications of my own to better balance param count vs accuracy.
+
 Hacked together by Ross Wightman
 """
 import math
@@ -29,7 +33,8 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     'skresnet18': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
-    'skresnet34': _cfg(url=''),
+    'skresnet34': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
     'skresnet50': _cfg(),
     'skresnet50d': _cfg(),
     'skresnext50_32x4d': _cfg(
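
With the weight URL filled in, skresnet34 can now be created with pretrained weights. A minimal sketch, not part of the commit and assuming the standard timm.create_model entrypoint with the default 224x224 config:

import torch
import timm

model = timm.create_model('skresnet34', pretrained=True)  # downloads skresnet34_ra-bdc0ccde.pth
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])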
