Skip to content

Commit 569419b

Browse files
committed
Tweak some comments, add SKNet models with weights to sotabench, remove an unused branch
1 parent 91e2b33 commit 569419b

File tree

7 files changed

+40
-24
lines changed

7 files changed

+40
-24
lines changed

sotabench.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
5656
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
5757
_entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
5858
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
59-
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
60-
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
59+
6160
_entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'),
6261
_entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
6362
_entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
@@ -82,14 +81,22 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
8281
_entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
8382
_entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2,
8483
model_desc='Ported from GluonCV Model Zoo'),
84+
8585
_entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"),
8686
_entry('mixnet_l', 'MixNet-L', '1907.09595'),
8787
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
8888
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
89+
90+
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
91+
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
8992
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
93+
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
94+
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
95+
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
9096
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
9197
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
9298
'paper as closely as possible.'),
99+
93100
_entry('resnet18', 'ResNet-18', '1812.01187'),
94101
_entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'),
95102
_entry('resnet26d', 'ResNet-26-D', '1812.01187',
@@ -103,7 +110,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
103110
_entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187',
104111
model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with "
105112
"SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"),
106-
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
113+
107114
_entry('seresnet18', 'SE-ResNet-18', '1709.01507'),
108115
_entry('seresnet34', 'SE-ResNet-34', '1709.01507'),
109116
_entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507',
@@ -114,8 +121,9 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
114121
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'),
115122
_entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187',
116123
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'),
117-
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
118-
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
124+
125+
_entry('skresnet18', 'SK-ResNet-18', '1903.06586'),
126+
_entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'),
119127

120128
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
121129
model_desc='Ported from official Google AI Tensorflow weights'),

timm/models/layers/cbam.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
33
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
44
5+
WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
6+
some tasks, especially fine-grained it seems. I may end up removing this impl.
7+
58
Hacked together by Ross Wightman
69
"""
710

timm/models/layers/cond_conv2d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
""" Conditional Convolution
1+
""" PyTorch Conditionally Parameterized Convolution (CondConv)
2+
3+
Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
4+
(https://arxiv.org/abs/1904.04971)
25
36
Hacked together by Ross Wightman
47
"""
@@ -28,7 +31,7 @@ def condconv_initializer(weight):
2831

2932

3033
class CondConv2d(nn.Module):
31-
""" Conditional Convolution
34+
""" Conditionally Parameterized Convolution
3235
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
3336
3437
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:

timm/models/layers/eca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class EcaModule(nn.Module):
4242
"""Constructs an ECA module.
4343
4444
Args:
45-
channel: Number of channels of the input feature map for use in adaptive kernel sizes
45+
channels: Number of channels of the input feature map for use in adaptive kernel sizes
4646
for actual calculations according to channel.
4747
gamma, beta: when channel is given parameters of mapping function
4848
refer to original paper https://arxiv.org/pdf/1910.03151.pdf

timm/models/layers/mixed_conv2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
""" Conditional Convolution
1+
""" PyTorch Mixed Convolution
2+
3+
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
24
35
Hacked together by Ross Wightman
46
"""

timm/models/layers/selective_kernel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
""" Selective Kernel Convolution Attention
1+
""" Selective Kernel Convolution/Attention
2+
3+
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
24
35
Hacked together by Ross Wightman
46
"""

timm/models/sknet.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
""" Selective Kernel Networks (ResNet base)
2+
3+
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
4+
5+
Hacked together by Ross Wightman
6+
"""
17
import math
28

39
from torch import nn as nn
@@ -47,19 +53,11 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
4753
outplanes = planes * self.expansion
4854
first_dilation = first_dilation or dilation
4955

50-
_selective_first = True # FIXME temporary, for experiments
51-
if _selective_first:
52-
self.conv1 = SelectiveKernelConv(
53-
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
54-
conv_kwargs['act_layer'] = None
55-
self.conv2 = ConvBnAct(
56-
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
57-
else:
58-
self.conv1 = ConvBnAct(
59-
inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
60-
conv_kwargs['act_layer'] = None
61-
self.conv2 = SelectiveKernelConv(
62-
first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs)
56+
self.conv1 = SelectiveKernelConv(
57+
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
58+
conv_kwargs['act_layer'] = None
59+
self.conv2 = ConvBnAct(
60+
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
6361
self.se = create_attn(attn_layer, outplanes)
6462
self.act = act_layer(inplace=True)
6563
self.downsample = downsample
@@ -222,7 +220,7 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
222220
@register_model
223221
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
224222
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
225-
the SKNet50 model in the Select Kernel Paper
223+
the SKNet-50 model in the Select Kernel Paper
226224
"""
227225
default_cfg = default_cfgs['skresnext50_32x4d']
228226
model = ResNet(

0 commit comments

Comments
 (0)