Commit b54f594: Add Res2Net and Res2NeXt

1 parent: 59e1ef4

File tree

6 files changed: +197 additions, -36 deletions


src/Metalhead.jl

Lines changed: 4 additions & 3 deletions

@@ -26,6 +26,7 @@ include("convnets/resnets/core.jl")
 include("convnets/resnets/resnet.jl")
 include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
+include("convnets/resnets/res2net.jl")
 ## Inceptions
 include("convnets/inception/googlenet.jl")
 include("convnets/inception/inceptionv3.jl")
@@ -57,16 +58,16 @@ include("pretrain.jl")
 
 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-       WideResNet, ResNeXt, SEResNet, SEResNeXt,
+       WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
        MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
-          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2,
+          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
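
With these exports in place, the two new families are constructed like any other
Metalhead model. A minimal usage sketch (pretrained weights for these
architectures may not be registered yet, so `pretrain = true` could fail):

    using Metalhead

    m1 = Res2Net(50)    # defaults: scale = 4, base_width = 26
    m2 = Res2NeXt(50)   # defaults: cardinality = 8, base_width = 4, scale = 4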

src/convnets/resnets/core.jl

Lines changed: 43 additions & 30 deletions

@@ -1,8 +1,9 @@
 """
-    basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -11,10 +12,11 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385
 - `inplanes`: number of input feature maps
 - `planes`: number of feature maps for the block
 - `stride`: the stride of the block
-- `reduction_factor`: the factor by which the input feature maps
-  are reduced before the first convolution.
+- `reduction_factor`: the factor by which the input feature maps are reduced before
+  the first convolution.
 - `activation`: the activation function to use.
 - `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `drop_block`: the drop block layer
 - `drop_path`: the drop path layer
 - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -36,11 +38,12 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
 end
 
 """
-    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-               reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+               cardinality::Integer = 1, base_width::Integer = 64,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -55,6 +58,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.
   convolution.
 - `activation`: the activation function to use.
 - `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `drop_block`: the drop block layer
 - `drop_path`: the drop path layer
 - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -153,7 +157,8 @@ on how to use this function.
   shows performance improvements over the `:deep` stem in some cases.
 
 - `inchannels`: The number of channels in the input.
-- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution +
+  normalization with a stride of two.
 - `norm_layer`: The normalisation layer used in the stem.
 - `activation`: The activation function used in the stem.
 """
@@ -253,8 +258,6 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
                         downsample_tuple[1] : downsample_tuple[2]
-        # DropBlock, DropPath both take in rates based on a linear scaling schedule
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
@@ -280,8 +283,7 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
 end
 
 function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
-                connection,
-                classifier_fn)
+                connection, classifier_fn)
     # Build stages of the ResNet
     stage_blocks = resnet_stages(get_layers, block_repeats, connection)
     backbone = Chain(stem, stage_blocks)
@@ -291,35 +293,46 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte
     return Chain(backbone, classifier)
 end
 
-function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
-                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
+                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
                 cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
                 reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
-                inchannels::Integer = 3, stem_fn = resnet_stem,
-                connection = addact, activation = relu, norm_layer = BatchNorm,
-                revnorm::Bool = false, attn_fn = planes -> identity,
-                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
-                nclasses::Integer = 1000)
+                inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
+                activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
+                attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
+                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
+                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
-    if block_type == :basicblock
+    if block_type === :basicblock
         @assert cardinality==1 "Cardinality must be 1 for `basicblock`"
         @assert base_width==64 "Base width must be 64 for `basicblock`"
         get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
                                         activation, norm_layer, revnorm, attn_fn,
                                         drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
-    elseif block_type == :bottleneck
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type === :bottleneck
         get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
-                                        reduction_factor, activation, norm_layer,
-                                        revnorm, attn_fn, drop_block_rate, drop_path_rate,
+                                        reduction_factor, activation, norm_layer, revnorm,
+                                        attn_fn, drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type === :bottle2neck
+        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`"
+        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`"
+        @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`"
+        get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
+                                         activation, norm_layer, revnorm, attn_fn,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt,
+                                         kwargs...)
     else
         # TODO: write better message when we have link to dev docs for resnet
         throw(ArgumentError("Unknown block type $block_type"))
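
The reworked dispatch moves `downsample_opt` from a keyword to the third
positional argument and forwards any leftover keywords to the block builder.
A sketch of an internal call (not public API; the argument values are
illustrative):

    # `[3, 4, 6, 3]` are the stage repeats of a depth-50 model. `scale` is not
    # a keyword of `resnet` itself; it reaches `bottle2neck_builder` through
    # `kwargs...`.
    layers = resnet(:bottle2neck, [3, 4, 6, 3];
                    base_width = 26, scale = 4, nclasses = 1000)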

src/convnets/resnets/res2net.jl

Lines changed: 146 additions & 0 deletions

@@ -0,0 +1,146 @@
+"""
+    bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                cardinality::Integer = 1, base_width::Integer = 26,
+                scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
+                revnorm::Bool = false, attn_fn = planes -> identity)
+
+Creates a bottleneck block as described in the Res2Net paper.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+- `inplanes`: number of input feature maps
+- `planes`: number of feature maps for the block
+- `stride`: the stride of the block
+- `cardinality`: the number of groups in the 3x3 convolutions.
+- `base_width`: the number of output feature maps for each convolutional group.
+- `scale`: the number of feature groups in the block. See the [paper](https://arxiv.org/abs/1904.01169)
+  for more details.
+- `activation`: the activation function to use.
+- `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
+- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
+"""
+function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                     cardinality::Integer = 1, base_width::Integer = 26,
+                     scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
+                     revnorm::Bool = false, attn_fn = planes -> identity)
+    width = fld(planes * base_width, 64) * cardinality
+    outplanes = planes * 4
+    is_first = stride > 1
+    pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity
+    conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride,
+                                pad = 1, groups = cardinality, bias = false)...)
+                for _ in 1:(max(1, scale - 1))]
+    reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
+               Parallel(cat_channels, identity, PairwiseFusion(+, conv_bns...))
+    tuplify(x) = is_first ? tuple(x...) : tuple(x[1], tuple(x[2:end]...))
+    return Chain(conv_norm((1, 1), inplanes => width * scale, activation;
+                           norm_layer, revnorm, bias = false)...,
+                 chunk$(; size = width, dims = 3),
+                 tuplify, reslayer,
+                 conv_norm((1, 1), width * scale => outplanes, activation;
+                           norm_layer, revnorm, bias = false)...,
+                 attn_fn(outplanes))
+end
+
+function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
+                             inplanes::Integer = 64, cardinality::Integer = 1,
+                             base_width::Integer = 26, scale::Integer = 4,
+                             expansion::Integer = 4, norm_layer = BatchNorm,
+                             revnorm::Bool = false, activation = relu,
+                             attn_fn = planes -> identity,
+                             stride_fn = resnet_stride, planes_fn = resnet_planes,
+                             downsample_tuple = (downsample_conv, downsample_identity))
+    planes_vec = collect(planes_fn(block_repeats))
+    # closure over `planes_vec`
+    function get_layers(stage_idx::Integer, block_idx::Integer)
+        # This is needed for block `inplanes` and `planes` calculations
+        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+        planes = planes_vec[schedule_idx]
+        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
+        # `resnet_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper
+        stride = stride_fn(stage_idx, block_idx)
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_tuple[1] : downsample_tuple[2]
+        block = bottle2neck(inplanes, planes; stride, cardinality, base_width, scale,
+                            activation, norm_layer, revnorm, attn_fn)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
+        return block, downsample
+    end
+    return get_layers
+end
+
+"""
+    Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+            base_width::Integer = 26, inchannels::Integer = 3,
+            nclasses::Integer = 1000)
+
+Creates a Res2Net model with the specified depth, scale, and base width.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+- `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model.
+- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+- `scale`: the number of feature groups in the block. See the
+  [paper](https://arxiv.org/abs/1904.01169) for more details.
+- `base_width`: the number of feature maps in each group.
+- `inchannels`: the number of input channels.
+- `nclasses`: the number of output classes
+"""
+struct Res2Net
+    layers::Any
+end
+@functor Res2Net
+(m::Res2Net)(x) = m.layers(x)
+
+function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+                 base_width::Integer = 26, inchannels::Integer = 3,
+                 nclasses::Integer = 1000)
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2], :C; base_width, scale,
+                    inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers, string("Res2Net", depth, "_", base_width, "x", scale))
+    end
+    return Res2Net(layers)
+end
+
+"""
+    Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+             base_width::Integer = 4, cardinality::Integer = 8,
+             inchannels::Integer = 3, nclasses::Integer = 1000)
+
+Creates a Res2NeXt model with the specified depth, scale, base width and cardinality.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+- `depth`: one of `[50, 101, 152]`. The depth of the Res2NeXt model.
+- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+- `scale`: the number of feature groups in the block. See the
+  [paper](https://arxiv.org/abs/1904.01169) for more details.
+- `base_width`: the number of feature maps in each group.
+- `cardinality`: the number of groups in the 3x3 convolutions.
+- `inchannels`: the number of input channels.
+- `nclasses`: the number of output classes
+"""
+struct Res2NeXt
+    layers::Any
+end
+@functor Res2NeXt
+(m::Res2NeXt)(x) = m.layers(x)
+
+function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+                  base_width::Integer = 4, cardinality::Integer = 8,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2], :C; base_width, scale,
+                    cardinality, inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers,
+                      string("Res2NeXt", depth, "_", base_width, "x", cardinality,
+                             "x", scale))
+    end
+    return Res2NeXt(layers)
+end
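
A usage sketch for the new constructors, assuming the usual Metalhead
convention of 224×224 images in WHCN layout (illustrative, not tested against
this commit):

    using Metalhead

    model = Res2Net(50)                  # the 26w×4s variant from the paper
    x = rand(Float32, 224, 224, 3, 1)    # one RGB image
    y = model(x)                         # `nclasses`-element scores per image

    model = Res2NeXt(101; scale = 4, base_width = 4, cardinality = 8)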

src/convnets/resnets/resnet.jl

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ end
 
 function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
                     nclasses::Integer = 1000)
-    _checkconfig(depth, [50, 101])
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
     layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("WideResNet", depth))
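
Assuming `RESNET_CONFIGS` is keyed by the standard depths `[18, 34, 50, 101, 152]`,
the new check (the same one `Res2Net` and `Res2NeXt` use above) resolves to
`[50, 101, 152]`, so a depth-152 wide model is now accepted:

    WideResNet(152)   # previously rejected by `_checkconfig(depth, [50, 101])`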

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ Create a convolution + batch normalization pair with activation.
 - `outplanes`: number of output feature maps
 - `activation`: the activation function for the final layer
 - `norm_layer`: the normalization layer used
-- `revnorm`: set to `true` to place the batch norm before the convolution
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `preact`: set to `true` to place the activation function before the batch norm
   (only compatible with `revnorm = false`)
 - `use_norm`: set to `false` to disable normalization
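
`conv_norm` returns an iterable of layers rather than a single `Chain`, which
is why call sites such as `bottle2neck` above splat it with `...`. A short
sketch of the documented `revnorm` behaviour (`conv_norm` is internal, so the
qualified name below is an assumption):

    using Flux, Metalhead

    layers = Metalhead.conv_norm((3, 3), 16 => 32, relu; revnorm = true, pad = 1)
    block = Chain(layers...)   # the normalisation layer runs before the convolution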

src/mixers/mlpmixer.jl

Lines changed: 2 additions & 1 deletion

@@ -56,7 +56,8 @@ struct MLPMixer
 end
 @functor MLPMixer
 
-function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16),
+function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224),
+                  patch_size::Dims{2} = (16, 16),
                   inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(config, keys(MIXER_CONFIGS))
     layers = mlpmixer(mixerblock, imsize; patch_size, MIXER_CONFIGS[config]..., inchannels,
