
Commit 08e30ce

Add Res2Net and Res2NeXt
1 parent d07bd6e commit 08e30ce

File tree

6 files changed: +192 additions, -34 deletions


src/Metalhead.jl

Lines changed: 4 additions & 3 deletions
@@ -26,6 +26,7 @@ include("convnets/resnets/core.jl")
 include("convnets/resnets/resnet.jl")
 include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
+include("convnets/resnets/res2net.jl")
 ## Inceptions
 include("convnets/inception/googlenet.jl")
 include("convnets/inception/inceptionv3.jl")
@@ -57,16 +58,16 @@ include("pretrain.jl")
 
 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-       WideResNet, ResNeXt, SEResNet, SEResNeXt,
+       WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
        MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
-          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4,
+          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
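
Note: with these exports in place, the two new models are available from the top-level namespace, and the `Base.show` registration above gives them the same pretty-printed output as the other architectures. A quick usage sketch against this revision (default hyperparameters per `res2net.jl` below):

    julia> using Metalhead

    julia> Res2Net(50)     # defaults: scale = 4, base_width = 26

    julia> Res2NeXt(50)    # defaults: scale = 4, base_width = 4, cardinality = 8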

src/convnets/resnets/core.jl

Lines changed: 42 additions & 28 deletions
@@ -1,8 +1,9 @@
 """
-    basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -11,10 +12,11 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385
 - `inplanes`: number of input feature maps
 - `planes`: number of feature maps for the block
 - `stride`: the stride of the block
-- `reduction_factor`: the factor by which the input feature maps
-  are reduced before the first convolution.
+- `reduction_factor`: the factor by which the input feature maps are reduced before
+  the first convolution.
 - `activation`: the activation function to use.
 - `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `drop_block`: the drop block layer
 - `drop_path`: the drop path layer
 - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -36,11 +38,12 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
 end
 
 """
-    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-               reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+               cardinality::Integer = 1, base_width::Integer = 64,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -55,6 +58,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.
   convolution.
 - `activation`: the activation function to use.
 - `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `drop_block`: the drop block layer
 - `drop_path`: the drop path layer
 - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -153,7 +157,8 @@ on how to use this function.
 shows peformance improvements over the `:deep` stem in some cases.
 
 - `inchannels`: The number of channels in the input.
-- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution +
+  normalization with a stride of two.
 - `norm_layer`: The normalisation layer used in the stem.
 - `activation`: The activation function used in the stem.
 """
@@ -253,8 +258,6 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
                         downsample_tuple[1] : downsample_tuple[2]
-        # DropBlock, DropPath both take in rates based on a linear scaling schedule
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
@@ -289,35 +292,46 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte
     return Chain(backbone, classifier_fn(nfeaturemaps))
 end
 
-function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
-                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
+                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
                 cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
                 reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
-                inchannels::Integer = 3, stem_fn = resnet_stem,
-                connection = addact, activation = relu, norm_layer = BatchNorm,
-                revnorm::Bool = false, attn_fn = planes -> identity,
-                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
-                nclasses::Integer = 1000)
+                inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
+                activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
+                attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
+                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
+                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
-    if block_type == :basicblock
+    if block_type === :basicblock
         @assert cardinality==1 "Cardinality must be 1 for `basicblock`"
         @assert base_width==64 "Base width must be 64 for `basicblock`"
         get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
                                         activation, norm_layer, revnorm, attn_fn,
                                         drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
-    elseif block_type == :bottleneck
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type === :bottleneck
         get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
-                                        reduction_factor, activation, norm_layer,
-                                        revnorm, attn_fn, drop_block_rate, drop_path_rate,
+                                        reduction_factor, activation, norm_layer, revnorm,
+                                        attn_fn, drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type === :bottle2neck
+        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`"
+        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`"
+        @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`"
+        get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
+                                         activation, norm_layer, revnorm, attn_fn,
+                                         stride_fn = resnet_stride,
+                                         planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt,
+                                         kwargs...)
     else
         # TODO: write better message when we have link to dev docs for resnet
         throw(ArgumentError("Unknown block type $block_type"))
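
Note: two things change in the `resnet` dispatcher. `downsample_opt` moves from a keyword to a positional argument, and a trailing `kwargs...` lets callers thread builder-specific options (such as `scale` for `bottle2neck_builder`) through `resnet` without it naming them. A reduced sketch of that passthrough pattern, with hypothetical `build`/`dispatcher` names standing in for the real functions:

    # Hypothetical stand-ins for bottle2neck_builder/resnet, not the Metalhead API.
    build(reps; scale = 4, base_width = 26) = (reps, scale, base_width)

    function dispatcher(block_type::Symbol, reps; nclasses::Integer = 1000, kwargs...)
        if block_type === :bottle2neck
            # `scale` is not a keyword of `dispatcher`: it rides through `kwargs...`.
            return build(reps; kwargs...)
        else
            throw(ArgumentError("Unknown block type $block_type"))
        end
    end

    dispatcher(:bottle2neck, [3, 4, 6, 3]; scale = 8)  # ([3, 4, 6, 3], 8, 26)

The `:C` shortcut that `Res2Net` passes in the positional slot below assumes a symbol-accepting `resnet` method defined elsewhere in core.jl.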

src/convnets/resnets/res2net.jl

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+"""
+    bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                cardinality::Integer = 1, base_width::Integer = 26,
+                scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
+                revnorm::Bool = false, attn_fn = planes -> identity)
+
+Creates a bottleneck block as described in the Res2Net paper.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+- `inplanes`: number of input feature maps
+- `planes`: number of feature maps for the block
+- `stride`: the stride of the block
+- `cardinality`: the number of groups in the 3x3 convolutions.
+- `base_width`: the number of output feature maps for each convolutional group.
+- `scale`: the number of feature groups in the block. See the [paper](https://arxiv.org/abs/1904.01169)
+  for more details.
+- `activation`: the activation function to use.
+- `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the batch norm before the convolution
+- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
+"""
+function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                     cardinality::Integer = 1, base_width::Integer = 26,
+                     scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
+                     revnorm::Bool = false, attn_fn = planes -> identity)
+    width = fld(planes * base_width, 64) * cardinality
+    outplanes = planes * 4
+    is_first = stride > 1
+    pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity
+    conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride,
+                                pad = 1, groups = cardinality, bias = false)...)
+                for _ in 1:(max(1, scale - 1))]
+    reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
+               Parallel(cat_channels, identity, PairwiseFusion(+, conv_bns...))
+    tuplify(x) = is_first ? tuple(x...) : tuple(x[1], tuple(x[2:end]...))
+    return Chain(conv_norm((1, 1), inplanes => width * scale, activation;
+                           norm_layer, revnorm, bias = false)...,
+                 chunk$(; size = width, dims = 3),
+                 tuplify, reslayer,
+                 conv_norm((1, 1), width * scale => outplanes, activation;
+                           norm_layer, revnorm, bias = false)...,
+                 attn_fn(outplanes))
+end
+
+function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
+                             inplanes::Integer = 64, cardinality::Integer = 1,
+                             base_width::Integer = 26, scale::Integer = 4,
+                             expansion::Integer = 4, norm_layer = BatchNorm,
+                             revnorm::Bool = false, activation = relu,
+                             attn_fn = planes -> identity,
+                             stride_fn = resnet_stride, planes_fn = resnet_planes,
+                             downsample_tuple = (downsample_conv, downsample_identity))
+    planes_vec = collect(planes_fn(block_repeats))
+    # closure over `idxs`
+    function get_layers(stage_idx::Integer, block_idx::Integer)
+        # This is needed for block `inplanes` and `planes` calculations
+        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+        planes = planes_vec[schedule_idx]
+        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
+        # `resnet_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper
+        stride = stride_fn(stage_idx, block_idx)
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_tuple[1] : downsample_tuple[2]
+        block = bottle2neck(inplanes, planes; stride, cardinality, base_width, scale,
+                            activation, norm_layer, revnorm, attn_fn)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
+        return block, downsample
+    end
+    return get_layers
+end
+
+"""
+    Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+            base_width::Integer = 26, inchannels::Integer = 3,
+            nclasses::Integer = 1000)
+
+Creates a Res2Net model with the specified depth, scale, and base width.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+- `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model.
+- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+- `scale`: the number of feature groups in the block. See the
+  [paper](https://arxiv.org/abs/1904.01169) for more details.
+- `base_width`: the number of feature maps in each group.
+- `inchannels`: the number of input channels.
+- `nclasses`: the number of output classes
+"""
+struct Res2Net
+    layers::Any
+end
+
+function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+                 base_width::Integer = 26, inchannels::Integer = 3,
+                 nclasses::Integer = 1000)
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2], :C; base_width, scale,
+                    inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers, string("Res2Net", depth, "_", base_width, "x", scale))
+    end
+    return ResNet(layers)
+end
+
+"""
+    Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+             base_width::Integer = 4, cardinality::Integer = 8,
+             inchannels::Integer = 3, nclasses::Integer = 1000)
+
+Creates a Res2NeXt model with the specified depth, scale, base width and cardinality.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+- `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model.
+- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+- `scale`: the number of feature groups in the block. See the
+  [paper](https://arxiv.org/abs/1904.01169) for more details.
+- `base_width`: the number of feature maps in each group.
+- `cardinality`: the number of groups in the 3x3 convolutions.
+- `inchannels`: the number of input channels.
+- `nclasses`: the number of output classes
+"""
+struct Res2NeXt
+    layers::Any
+end
+
+function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+                  base_width::Integer = 4, cardinality::Integer = 8,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2], :C; base_width, scale,
+                    cardinality, inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers,
+                      string("Res2NeXt", depth, "_", base_width, "x", cardinality,
+                             "x", scale))
+    end
+    return ResNet(layers)
+end
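
Note: the heart of `bottle2neck` is the multi-scale wiring: the `width * scale` channels from the first 1x1 convolution are split into `scale` groups by `chunk` (partially applied with the `$` operator from PartialFunctions.jl), the first group bypasses the 3x3 stage (`identity`, or the mean pool when downsampling), and the remaining groups pass through the grouped 3x3 convolutions with hierarchical addition via `PairwiseFusion` before being concatenated again. A standalone sketch of that wiring in plain Flux, with toy sizes rather than the Metalhead internals:

    using Flux, MLUtils

    width, scale = 8, 4
    x = rand(Float32, 16, 16, width * scale, 1)

    # Split the channel dimension into `scale` groups of `width` channels each.
    xs = chunk(x; size = width, dims = 3)

    convs = [Conv((3, 3), width => width, relu; pad = 1) for _ in 1:(scale - 1)]

    # Hierarchical fusion: y1 = conv1(x2); y2 = conv2(y1 + x3); y3 = conv3(y2 + x4).
    ys = PairwiseFusion(+, convs...)(Tuple(xs[2:end]))

    # The first group skips the 3x3 stage; concatenate all groups on channels.
    out = cat(xs[1], ys...; dims = 3)
    size(out)  # (16, 16, 32, 1)

Constructing and running the exported models should then follow the usual Metalhead pattern, e.g. `Res2NeXt(50)(rand(Float32, 224, 224, 3, 1))` for a (1000, 1) array of class logits.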

src/convnets/resnets/resnet.jl

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ end
 
 function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
                     nclasses::Integer = 1000)
-    _checkconfig(depth, [50, 101])
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
     layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("WideResNet", depth))
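
Note: this derives the set of valid depths from `RESNET_CONFIGS` instead of hard-coding `[50, 101]`, matching the check used by the new `Res2Net`/`Res2NeXt` constructors. Assuming the keys are the five standard depths, the index trick keeps only the bottleneck configurations, so `WideResNet` now also accepts depth 152:

    # Assumed RESNET_CONFIGS keys in this revision: the five standard ResNet depths.
    configs = Dict(18 => :basic, 34 => :basic, 50 => :bottleneck,
                   101 => :bottleneck, 152 => :bottleneck)
    sort(collect(keys(configs)))[3:end]  # [50, 101, 152]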

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ Create a convolution + batch normalization pair with activation.
 - `outplanes`: number of output feature maps
 - `activation`: the activation function for the final layer
 - `norm_layer`: the normalization layer used
-- `revnorm`: set to `true` to place the batch norm before the convolution
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `preact`: set to `true` to place the activation function before the batch norm
   (only compatible with `revnorm = false`)
 - `use_norm`: set to `false` to disable normalization

src/mixers/mlpmixer.jl

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,8 @@ struct MLPMixer
 end
 @functor MLPMixer
 
-function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16),
+function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224),
+                  patch_size::Dims{2} = (16, 16),
                   inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(config, keys(MIXER_CONFIGS))
     layers = mlpmixer(mixerblock, imsize; patch_size, MIXER_CONFIGS[config]..., inchannels,
