
Commit 842fa99

Merge pull request #195 from theabhirath/res2net-again
Res2Net and Res2NeXt, again
2 parents 1e4c669 + 76d5b7e

File tree: 15 files changed, +278 −57 lines

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ jobs:
           - '["EfficientNet"]'
           - 'r"/*/ResNet*"'
           - '[r"ResNeXt", r"SEResNet"]'
+          - '[r"Res2Net", r"Res2NeXt"]'
           - '"Inception"'
           - '"DenseNet"'
           - '["ConvNeXt", "ConvMixer"]'

src/Metalhead.jl

Lines changed: 4 additions & 3 deletions
@@ -26,6 +26,7 @@ include("convnets/resnets/core.jl")
 include("convnets/resnets/resnet.jl")
 include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
+include("convnets/resnets/res2net.jl")
 ## Inceptions
 include("convnets/inception/googlenet.jl")
 include("convnets/inception/inceptionv3.jl")
@@ -57,16 +58,16 @@ include("pretrain.jl")
 
 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-       WideResNet, ResNeXt, SEResNet, SEResNeXt,
+       WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
        MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
-          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4,
+          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end

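With `res2net.jl` included and the new names exported, Res2Net and Res2NeXt can be used like the other ResNet-family wrappers. A minimal sketch, assuming the constructors follow the same depth-based pattern as `ResNet` and `ResNeXt` (the contents of `res2net.jl` are not shown in this excerpt, so the exact keywords are assumptions):

using Metalhead

# Assumed constructor pattern, mirroring the other ResNet-family wrappers;
# keywords such as `pretrain` are not confirmed by this diff.
m1 = Res2Net(50)    # a 50-layer Res2Net configuration
m2 = Res2NeXt(50)   # the grouped-convolution (cardinality > 1) variant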
src/convnets/inception/googlenet.jl

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ Create an Inception-v1 model (commonly referred to as GoogLeNet)
 
 - `nclasses`: the number of output classes
 """
-function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
+function googlenet(; dropout_rate = 0.4, inchannels::Integer = 3, nclasses::Integer = 1000)
     backbone = Chain(Conv((7, 7), inchannels => 64; stride = 2, pad = 3),
                      MaxPool((3, 3); stride = 2, pad = 1),
                      Conv((1, 1), 64 => 64),
@@ -53,7 +53,7 @@ function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
                      MaxPool((3, 3); stride = 2, pad = 1),
                      _inceptionblock(832, 256, 160, 320, 32, 128, 128),
                      _inceptionblock(832, 384, 192, 384, 48, 128, 128))
-    return Chain(backbone, create_classifier(1024, nclasses; dropout_rate = 0.4))
+    return Chain(backbone, create_classifier(1024, nclasses; dropout_rate))
 end
 
 """

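This change surfaces the classifier dropout as a `dropout_rate` keyword instead of a hard-coded value, keeping the old 0.4 as the default; the `inceptionv3` change below does the same with a default of 0.2. A minimal sketch of calling the lowercase, non-exported builder directly, using only the signature shown above:

# 0.4 remains the default rate, so existing behaviour is unchanged.
m_default = googlenet()                      # dropout_rate = 0.4, as before
m_light   = googlenet(; dropout_rate = 0.1)  # weaker regularisation
m_nodrop  = googlenet(; dropout_rate = 0.0)  # disable dropout entirely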
src/convnets/inception/inceptionresnetv2.jl

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,8 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0,
 end
 
 """
-    InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3,
+                      nclasses::Integer = 1000)
 
 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))
@@ -118,9 +119,8 @@ end
 @functor InceptionResNetv2
 
 function InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3,
-                           dropout_rate = 0.0,
                            nclasses::Integer = 1000)
-    layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses)
+    layers = inceptionresnetv2(; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, "InceptionResNetv2")
     end

src/convnets/inception/inceptionv3.jl

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
 
 - `nclasses`: the number of output classes
 """
-function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
+function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000)
     backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
                      conv_norm((3, 3), 32, 32)...,
                      conv_norm((3, 3), 32, 64; pad = 1)...,
@@ -152,7 +152,7 @@ function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
                      inceptionv3_d(768),
                      inceptionv3_e(1280),
                      inceptionv3_e(2048))
-    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate = 0.2))
+    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
 end
 
 """

src/convnets/inception/xception.jl

Lines changed: 1 addition & 2 deletions
@@ -66,8 +66,7 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integ
                      xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
                      depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
                      depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
-    classifier = create_classifier(2048, nclasses; dropout_rate)
-    return Chain(backbone, classifier)
+    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
 end
 
 """

src/convnets/resnets/core.jl

Lines changed: 53 additions & 35 deletions
@@ -1,8 +1,9 @@
 """
-    basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -11,10 +12,11 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385
 - `inplanes`: number of input feature maps
 - `planes`: number of feature maps for the block
 - `stride`: the stride of the block
-- `reduction_factor`: the factor by which the input feature maps
-  are reduced before the first convolution.
+- `reduction_factor`: the factor by which the input feature maps are reduced before
+  the first convolution.
 - `activation`: the activation function to use.
 - `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `drop_block`: the drop block layer
 - `drop_path`: the drop path layer
 - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -36,11 +38,12 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
 end
 
 """
-    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-               reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+               cardinality::Integer = 1, base_width::Integer = 64,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -55,6 +58,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.
   convolution.
 - `activation`: the activation function to use.
 - `norm_layer`: the normalization layer to use.
+- `revnorm`: set to `true` to place the normalisation layer before the convolution
 - `drop_block`: the drop block layer
 - `drop_path`: the drop path layer
 - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -112,7 +116,7 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
 end
 
 # Shortcut configurations for the ResNet models
-const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
+const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity),
                            :B => (downsample_conv, downsample_identity),
                            :C => (downsample_conv, downsample_conv),
                            :D => (downsample_pool, downsample_identity))
@@ -153,7 +157,8 @@ on how to use this function.
   shows peformance improvements over the `:deep` stem in some cases.
 
 - `inchannels`: The number of channels in the input.
-- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution +
+  normalization with a stride of two.
 - `norm_layer`: The normalisation layer used in the stem.
 - `activation`: The activation function used in the stem.
 """
@@ -253,8 +258,6 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
                         downsample_tuple[1] : downsample_tuple[2]
-        # DropBlock, DropPath both take in rates based on a linear scaling schedule
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
@@ -289,35 +292,46 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte
     return Chain(backbone, classifier_fn(nfeaturemaps))
 end
 
-function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
-                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+function resnet(block_type, block_repeats::AbstractVector{<:Integer},
+                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
                 cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
                 reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
-                inchannels::Integer = 3, stem_fn = resnet_stem,
-                connection = addact, activation = relu, norm_layer = BatchNorm,
-                revnorm::Bool = false, attn_fn = planes -> identity,
-                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
-                nclasses::Integer = 1000)
+                inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
+                activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
+                attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
+                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
+                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
-    if block_type == :basicblock
+    if block_type == basicblock
        @assert cardinality==1 "Cardinality must be 1 for `basicblock`"
        @assert base_width==64 "Base width must be 64 for `basicblock`"
        get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
                                        activation, norm_layer, revnorm, attn_fn,
                                        drop_block_rate, drop_path_rate,
                                        stride_fn = resnet_stride,
                                        planes_fn = resnet_planes,
-                                       downsample_tuple = downsample_opt)
-    elseif block_type == :bottleneck
+                                       downsample_tuple = downsample_opt,
+                                       kwargs...)
+    elseif block_type == bottleneck
        get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
-                                       reduction_factor, activation, norm_layer,
-                                       revnorm, attn_fn, drop_block_rate, drop_path_rate,
+                                       reduction_factor, activation, norm_layer, revnorm,
+                                       attn_fn, drop_block_rate, drop_path_rate,
                                        stride_fn = resnet_stride,
                                        planes_fn = resnet_planes,
-                                       downsample_tuple = downsample_opt)
+                                       downsample_tuple = downsample_opt,
+                                       kwargs...)
+    elseif block_type == bottle2neck
+        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`"
+        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`"
+        @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`"
+        get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
+                                         activation, norm_layer, revnorm, attn_fn,
+                                         stride_fn = resnet_stride,
+                                         planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt,
+                                         kwargs...)
    else
        # TODO: write better message when we have link to dev docs for resnet
        throw(ArgumentError("Unknown block type $block_type"))
@@ -328,12 +342,16 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
                   connection$activation, classifier_fn)
 end
 function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
-    return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
+    return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)
 end
 
 # block-layer configurations for ResNet-like models
-const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
-                            34 => (:basicblock, [3, 4, 6, 3]),
-                            50 => (:bottleneck, [3, 4, 6, 3]),
-                            101 => (:bottleneck, [3, 4, 23, 3]),
-                            152 => (:bottleneck, [3, 8, 36, 3]))
+const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]),
+                            34 => (basicblock, [3, 4, 6, 3]),
+                            50 => (bottleneck, [3, 4, 6, 3]),
+                            101 => (bottleneck, [3, 4, 23, 3]),
+                            152 => (bottleneck, [3, 8, 36, 3]))
+
+const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]),
+                             101 => (bottleneck, [3, 4, 23, 3]),
+                             152 => (bottleneck, [3, 8, 36, 3]))

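The net effect of these changes is that block types are now first-class functions rather than `Symbol`s: `resnet` compares `block_type` against `basicblock`, `bottleneck`, and the new `bottle2neck`, and the trailing `kwargs...` forwards block-specific options to the matching builder. A minimal sketch of this internal, non-exported API based only on what the diff shows; the `scale` keyword is a hypothetical example of a `bottle2neck`-specific option:

# Look up a standard configuration and build a ResNet-50-shaped model.
block_fn, block_repeats = RESNET_CONFIGS[50]   # (bottleneck, [3, 4, 6, 3])
layers = resnet(block_fn, block_repeats, :B)   # :B resolves via RESNET_SHORTCUTS

# The `kwargs...` forwarding is what lets a Res2Net wrapper pass options that
# only `bottle2neck_builder` understands, e.g. a hypothetical `scale` keyword:
layers = resnet(bottle2neck, [3, 4, 6, 3], :B; base_width = 26, scale = 4)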