Skip to content

Commit d3e4add

Browse files
committed
Refactor: mbconv instead of invertedresidual
Also fix bug in EfficientNet models
1 parent 7d56396 commit d3e4add

File tree

9 files changed

+178
-145
lines changed

9 files changed

+178
-145
lines changed

src/convnets/efficientnet/efficientnet.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,37 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
1717
+ `e`: expansion ratio
1818
+ `i`: block input channels (will be scaled by global width scaling)
1919
+ `o`: block output channels (will be scaled by global width scaling)
20-
- `max_width`: The maximum number of feature maps in any layer of the network
2120
- `inchannels`: number of input channels
2221
- `nclasses`: number of output classes
2322
"""
2423
function efficientnet(scalings::NTuple{2, Real},
2524
block_configs::AbstractVector{NTuple{6, Int}};
26-
max_width::Integer = 1280, inchannels::Integer = 3,
27-
nclasses::Integer = 1000)
25+
inchannels::Integer = 3, nclasses::Integer = 1000)
26+
# building first layer
2827
wscale, dscale = scalings
2928
scalew(w) = wscale 1 ? w : ceil(Int64, wscale * w)
3029
scaled(d) = dscale 1 ? d : ceil(Int64, dscale * d)
3130
outplanes = _round_channels(scalew(32), 8)
3231
stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2,
3332
pad = SamePad())
33+
# building inverted residual blocks
3434
blocks = []
3535
for (n, k, s, e, i, o) in block_configs
3636
inchannels = _round_channels(scalew(i), 8)
37+
explanes = _round_channels(inchannels * e, 8)
3738
outplanes = _round_channels(scalew(o), 8)
3839
repeats = scaled(n)
3940
push!(blocks,
40-
invertedresidual((k, k), inchannels, outplanes, swish; expansion = e,
41-
stride = s, reduction = 4))
41+
mbconv((k, k), inchannels, explanes, outplanes, swish;
42+
stride = s, reduction = 4))
4243
for _ in 1:(repeats - 1)
4344
push!(blocks,
44-
invertedresidual((k, k), outplanes, outplanes, swish; expansion = e,
45-
stride = 1, reduction = 4))
45+
mbconv((k, k), outplanes, explanes, outplanes, swish;
46+
stride = 1, reduction = 4))
4647
end
4748
end
48-
headplanes = _round_channels(max_width, 8)
49+
# building last layers
50+
headplanes = outplanes * 4
4951
append!(blocks,
5052
conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad()))
5153
return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses))

src/convnets/efficientnet/efficientnetv2.jl

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,19 @@ function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 17
3232
conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2,
3333
bias = false))
3434
# building inverted residual blocks
35-
for (t, c, n, s, reduction) in config
36-
outplanes = _round_channels(c * width_mult, 8)
35+
for (t, inplanes, outplanes, n, s, reduction) in config
36+
explanes = _round_channels(inplanes * t, 8)
3737
for i in 1:n
38-
push!(layers,
39-
invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t,
40-
stride = i == 1 ? s : 1, reduction))
38+
stride = i == 1 ? s : 1
39+
if isnothing(reduction)
40+
push!(layers,
41+
fused_mbconv((3, 3), inplanes, explanes, outplanes, swish; stride))
42+
else
43+
inplanes = _round_channels(inplanes * width_mult, 8)
44+
outplanes = _round_channels(outplanes * width_mult, 8)
45+
push!(layers,
46+
mbconv((3, 3), inplanes, explanes, outplanes, swish; stride))
47+
end
4148
inplanes = outplanes
4249
end
4350
end
@@ -49,37 +56,37 @@ function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 17
4956
end
5057

5158
# block configs for EfficientNetv2
52-
const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, r
53-
(1, 24, 2, 1, nothing),
54-
(4, 48, 4, 2, nothing),
55-
(4, 64, 4, 2, nothing),
56-
(4, 128, 6, 2, 4),
57-
(6, 160, 9, 1, 4),
58-
(6, 256, 15, 2, 4)],
59-
:medium => [# t, c, n, s, r
60-
(1, 24, 3, 1, nothing),
61-
(4, 48, 5, 2, nothing),
62-
(4, 80, 5, 2, nothing),
63-
(4, 160, 7, 2, 4),
64-
(6, 176, 14, 1, 4),
65-
(6, 304, 18, 2, 4),
66-
(6, 512, 5, 1, 4)],
67-
:large => [# t, c, n, s, r
68-
(1, 32, 4, 1, nothing),
69-
(4, 64, 8, 2, nothing),
70-
(4, 96, 8, 2, nothing),
71-
(4, 192, 16, 2, 4),
72-
(6, 256, 24, 1, 4),
73-
(6, 512, 32, 2, 4),
74-
(6, 640, 8, 1, 4)],
75-
:xlarge => [# t, c, n, s, r
76-
(1, 32, 4, 1, nothing),
77-
(4, 64, 8, 2, nothing),
78-
(4, 96, 8, 2, nothing),
79-
(4, 192, 16, 2, 4),
80-
(6, 256, 24, 1, 4),
81-
(6, 512, 32, 2, 4),
82-
(6, 640, 8, 1, 4)])
59+
const EFFNETV2_CONFIGS = Dict(:small => [
60+
(1, 24, 24, 2, 1, nothing),
61+
(4, 24, 48, 4, 2, nothing),
62+
(4, 48, 64, 4, 2, nothing),
63+
(4, 64, 128, 6, 2, 4),
64+
(6, 128, 160, 9, 1, 4),
65+
(6, 160, 256, 15, 2, 4)],
66+
:medium => [
67+
(1, 24, 24, 3, 1, nothing),
68+
(4, 24, 48, 5, 2, nothing),
69+
(4, 48, 80, 5, 2, nothing),
70+
(4, 80, 160, 7, 2, 4),
71+
(6, 160, 176, 14, 1, 4),
72+
(6, 176, 304, 18, 2, 4),
73+
(6, 304, 512, 5, 1, 4)],
74+
:large => [
75+
(1, 32, 32, 4, 1, nothing),
76+
(4, 32, 64, 7, 2, nothing),
77+
(4, 64, 96, 7, 2, nothing),
78+
(4, 96, 192, 10, 2, 4),
79+
(6, 192, 224, 19, 1, 4),
80+
(6, 224, 384, 25, 2, 4),
81+
(6, 384, 640, 7, 1, 4)],
82+
:xlarge => [
83+
(1, 32, 32, 4, 1, nothing),
84+
(4, 32, 64, 8, 2, nothing),
85+
(4, 64, 96, 8, 2, nothing),
86+
(4, 96, 192, 16, 2, 4),
87+
(6, 192, 256, 24, 1, 4),
88+
(6, 256, 512, 32, 2, 4),
89+
(6, 512, 640, 8, 1, 4)])
8390

8491
"""
8592
EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1,

src/convnets/inception/inceptionresnetv2.jl

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,64 @@
11
function mixed_5b()
2-
branch1 = Chain(conv_norm((1, 1), 192, 96)...)
3-
branch2 = Chain(conv_norm((1, 1), 192, 48)...,
4-
conv_norm((5, 5), 48, 64; pad = 2)...)
5-
branch3 = Chain(conv_norm((1, 1), 192, 64)...,
6-
conv_norm((3, 3), 64, 96; pad = 1)...,
7-
conv_norm((3, 3), 96, 96; pad = 1)...)
2+
branch1 = Chain(basic_conv_bn((1, 1), 192, 96)...)
3+
branch2 = Chain(basic_conv_bn((1, 1), 192, 48)...,
4+
basic_conv_bn((5, 5), 48, 64; pad = 2)...)
5+
branch3 = Chain(basic_conv_bn((1, 1), 192, 64)...,
6+
basic_conv_bn((3, 3), 64, 96; pad = 1)...,
7+
basic_conv_bn((3, 3), 96, 96; pad = 1)...)
88
branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1),
9-
conv_norm((1, 1), 192, 64)...)
9+
basic_conv_bn((1, 1), 192, 64)...)
1010
return Parallel(cat_channels, branch1, branch2, branch3, branch4)
1111
end
1212

1313
function block35(scale = 1.0f0)
14-
branch1 = Chain(conv_norm((1, 1), 320, 32)...)
15-
branch2 = Chain(conv_norm((1, 1), 320, 32)...,
16-
conv_norm((3, 3), 32, 32; pad = 1)...)
17-
branch3 = Chain(conv_norm((1, 1), 320, 32)...,
18-
conv_norm((3, 3), 32, 48; pad = 1)...,
19-
conv_norm((3, 3), 48, 64; pad = 1)...)
20-
branch4 = Chain(conv_norm((1, 1), 128, 320)...)
14+
branch1 = Chain(basic_conv_bn((1, 1), 320, 32)...)
15+
branch2 = Chain(basic_conv_bn((1, 1), 320, 32)...,
16+
basic_conv_bn((3, 3), 32, 32; pad = 1)...)
17+
branch3 = Chain(basic_conv_bn((1, 1), 320, 32)...,
18+
basic_conv_bn((3, 3), 32, 48; pad = 1)...,
19+
basic_conv_bn((3, 3), 48, 64; pad = 1)...)
20+
branch4 = Chain(basic_conv_bn((1, 1), 128, 320)...)
2121
return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3),
2222
branch4, inputscale(scale; activation = relu)), +)
2323
end
2424

2525
function mixed_6a()
26-
branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...)
27-
branch2 = Chain(conv_norm((1, 1), 320, 256)...,
28-
conv_norm((3, 3), 256, 256; pad = 1)...,
29-
conv_norm((3, 3), 256, 384; stride = 2)...)
26+
branch1 = Chain(basic_conv_bn((3, 3), 320, 384; stride = 2)...)
27+
branch2 = Chain(basic_conv_bn((1, 1), 320, 256)...,
28+
basic_conv_bn((3, 3), 256, 256; pad = 1)...,
29+
basic_conv_bn((3, 3), 256, 384; stride = 2)...)
3030
branch3 = MaxPool((3, 3); stride = 2)
3131
return Parallel(cat_channels, branch1, branch2, branch3)
3232
end
3333

3434
function block17(scale = 1.0f0)
35-
branch1 = Chain(conv_norm((1, 1), 1088, 192)...)
36-
branch2 = Chain(conv_norm((1, 1), 1088, 128)...,
37-
conv_norm((7, 1), 128, 160; pad = (3, 0))...,
38-
conv_norm((1, 7), 160, 192; pad = (0, 3))...)
39-
branch3 = Chain(conv_norm((1, 1), 384, 1088)...)
35+
branch1 = Chain(basic_conv_bn((1, 1), 1088, 192)...)
36+
branch2 = Chain(basic_conv_bn((1, 1), 1088, 128)...,
37+
basic_conv_bn((7, 1), 128, 160; pad = (3, 0))...,
38+
basic_conv_bn((1, 7), 160, 192; pad = (0, 3))...)
39+
branch3 = Chain(basic_conv_bn((1, 1), 384, 1088)...)
4040
return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2),
4141
branch3, inputscale(scale; activation = relu)), +)
4242
end
4343

4444
function mixed_7a()
45-
branch1 = Chain(conv_norm((1, 1), 1088, 256)...,
46-
conv_norm((3, 3), 256, 384; stride = 2)...)
47-
branch2 = Chain(conv_norm((1, 1), 1088, 256)...,
48-
conv_norm((3, 3), 256, 288; stride = 2)...)
49-
branch3 = Chain(conv_norm((1, 1), 1088, 256)...,
50-
conv_norm((3, 3), 256, 288; pad = 1)...,
51-
conv_norm((3, 3), 288, 320; stride = 2)...)
45+
branch1 = Chain(basic_conv_bn((1, 1), 1088, 256)...,
46+
basic_conv_bn((3, 3), 256, 384; stride = 2)...)
47+
branch2 = Chain(basic_conv_bn((1, 1), 1088, 256)...,
48+
basic_conv_bn((3, 3), 256, 288; stride = 2)...)
49+
branch3 = Chain(basic_conv_bn((1, 1), 1088, 256)...,
50+
basic_conv_bn((3, 3), 256, 288; pad = 1)...,
51+
basic_conv_bn((3, 3), 288, 320; stride = 2)...)
5252
branch4 = MaxPool((3, 3); stride = 2)
5353
return Parallel(cat_channels, branch1, branch2, branch3, branch4)
5454
end
5555

5656
function block8(scale = 1.0f0; activation = identity)
57-
branch1 = Chain(conv_norm((1, 1), 2080, 192)...)
58-
branch2 = Chain(conv_norm((1, 1), 2080, 192)...,
59-
conv_norm((3, 1), 192, 224; pad = (1, 0))...,
60-
conv_norm((1, 3), 224, 256; pad = (0, 1))...)
61-
branch3 = Chain(conv_norm((1, 1), 448, 2080)...)
57+
branch1 = Chain(basic_conv_bn((1, 1), 2080, 192)...)
58+
branch2 = Chain(basic_conv_bn((1, 1), 2080, 192)...,
59+
basic_conv_bn((3, 1), 192, 224; pad = (1, 0))...,
60+
basic_conv_bn((1, 3), 224, 256; pad = (0, 1))...)
61+
branch3 = Chain(basic_conv_bn((1, 1), 448, 2080)...)
6262
return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2),
6363
branch3, inputscale(scale; activation)), +)
6464
end
@@ -77,12 +77,12 @@ Creates an InceptionResNetv2 model.
7777
"""
7878
function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3,
7979
nclasses::Integer = 1000)
80-
backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
81-
conv_norm((3, 3), 32, 32)...,
82-
conv_norm((3, 3), 32, 64; pad = 1)...,
80+
backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)...,
81+
basic_conv_bn((3, 3), 32, 32)...,
82+
basic_conv_bn((3, 3), 32, 64; pad = 1)...,
8383
MaxPool((3, 3); stride = 2),
84-
conv_norm((3, 3), 64, 80)...,
85-
conv_norm((3, 3), 80, 192)...,
84+
basic_conv_bn((3, 3), 64, 80)...,
85+
basic_conv_bn((3, 3), 80, 192)...,
8686
MaxPool((3, 3); stride = 2),
8787
mixed_5b(),
8888
[block35(0.17f0) for _ in 1:10]...,
@@ -91,7 +91,7 @@ function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3,
9191
mixed_7a(),
9292
[block8(0.20f0) for _ in 1:9]...,
9393
block8(; activation = relu),
94-
conv_norm((1, 1), 2080, 1536)...)
94+
basic_conv_bn((1, 1), 2080, 1536)...)
9595
return Chain(backbone, create_classifier(1536, nclasses; dropout_rate))
9696
end
9797

src/convnets/mobilenet/mobilenetv2.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
3737
outplanes = _round_channels(c * width_mult, divisor)
3838
for i in 1:n
3939
push!(layers,
40-
invertedresidual((3, 3), inplanes, outplanes, a; expansion = t,
41-
stride = i == 1 ? s : 1))
40+
mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a;
41+
stride = i == 1 ? s : 1))
4242
inplanes = outplanes
4343
end
4444
end

src/convnets/mobilenet/mobilenetv3.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ Create a MobileNetv3 model.
2424
- `nclasses`: the number of output classes
2525
"""
2626
function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
27-
max_width::Integer = 1024, inchannels::Integer = 3,
28-
nclasses::Integer = 1000)
27+
max_width::Integer = 1024, dropout_rate = 0.2,
28+
inchannels::Integer = 3, nclasses::Integer = 1000)
2929
# building first layer
3030
inplanes = _round_channels(16 * width_mult, 8)
3131
layers = []
@@ -39,8 +39,8 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
3939
outplanes = _round_channels(c * width_mult, 8)
4040
explanes = _round_channels(inplanes * t, 8)
4141
push!(layers,
42-
invertedresidual((k, k), inplanes, explanes, outplanes, activation;
43-
stride, reduction))
42+
mbconv((k, k), inplanes, explanes, outplanes, activation;
43+
stride, reduction))
4444
inplanes = outplanes
4545
end
4646
# building last layers
@@ -49,7 +49,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
4949
append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false))
5050
classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten,
5151
Dense(explanes, headplanes, hardswish),
52-
Dropout(0.2),
52+
Dropout(dropout_rate),
5353
Dense(headplanes, nclasses))
5454
return Chain(Chain(layers...), classifier)
5555
end

src/layers/Layers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include("attention.jl")
1919
export MHAttention
2020

2121
include("conv.jl")
22-
export conv_norm, basic_conv_bn, dwsep_conv_bn, invertedresidual
22+
export conv_norm, basic_conv_bn, dwsep_conv_bn, mbconv, fused_mbconv
2323

2424
include("drop.jl")
2525
export DropBlock, DropPath

0 commit comments

Comments
 (0)