Skip to content

Commit 5aece44

Browse files
committed
Use create_classifier more
1 parent 4e46d7b commit 5aece44

19 files changed

+175
-196
lines changed

src/convnets/alexnet.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,21 @@ Create an AlexNet model
99
- `nclasses`: the number of output classes
1010
"""
1111
function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
12-
layers = Chain(Chain(Conv((11, 11), inchannels => 64, relu; stride = (4, 4), pad = (2, 2)),
13-
MaxPool((3, 3); stride = (2, 2)),
14-
Conv((5, 5), 64 => 192, relu; pad = (2, 2)),
15-
MaxPool((3, 3); stride = (2, 2)),
16-
Conv((3, 3), 192 => 384, relu; pad = (1, 1)),
17-
Conv((3, 3), 384 => 256, relu; pad = (1, 1)),
18-
Conv((3, 3), 256 => 256, relu; pad = (1, 1)),
19-
MaxPool((3, 3); stride = (2, 2)),
20-
AdaptiveMeanPool((6, 6))),
21-
Chain(MLUtils.flatten,
22-
Dropout(0.5),
23-
Dense(256 * 6 * 6, 4096, relu),
24-
Dropout(0.5),
25-
Dense(4096, 4096, relu),
26-
Dense(4096, nclasses)))
27-
return layers
12+
backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
13+
MaxPool((3, 3); stride = 2),
14+
Conv((5, 5), 64 => 192, relu; pad = 2),
15+
MaxPool((3, 3); stride = 2),
16+
Conv((3, 3), 192 => 384, relu; pad = 1),
17+
Conv((3, 3), 384 => 256, relu; pad = 1),
18+
Conv((3, 3), 256 => 256, relu; pad = 1),
19+
MaxPool((3, 3); stride = 2))
20+
classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten,
21+
Dropout(0.5),
22+
Dense(256 * 6 * 6, 4096, relu),
23+
Dropout(0.5),
24+
Dense(4096, 4096, relu),
25+
Dense(4096, nclasses))
26+
return Chain(backbone, classifier)
2827
end
2928

3029
"""
@@ -47,7 +46,8 @@ struct AlexNet
4746
end
4847
@functor AlexNet
4948

50-
function AlexNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
49+
function AlexNet(; pretrain::Bool = false, inchannels::Integer = 3,
50+
nclasses::Integer = 1000)
5151
layers = alexnet(; inchannels, nclasses)
5252
if pretrain
5353
loadpretrain!(layers, "AlexNet")

src/convnets/convmixer.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
2626
pad = SamePad())), +),
2727
conv_norm((1, 1), planes, planes, activation; preact = true)...)
2828
for _ in 1:depth]
29-
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
30-
return Chain(Chain(stem..., Chain(blocks)), head)
29+
return Chain(Chain(stem..., Chain(blocks)), create_classifier(planes, nclasses))
3130
end
3231

3332
const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,

src/convnets/convnext.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,10 @@ function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
6363
cur += depths[i]
6464
end
6565
backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
66-
head = Chain(GlobalMeanPool(),
67-
MLUtils.flatten,
68-
LayerNorm(planes[end]),
69-
Dense(planes[end], nclasses))
70-
return Chain(Chain(backbone), head)
66+
classifier = Chain(GlobalMeanPool(), MLUtils.flatten,
67+
LayerNorm(planes[end]),
68+
Dense(planes[end], nclasses))
69+
return Chain(Chain(backbone...), classifier)
7170
end
7271

7372
# Configurations for ConvNeXt models

src/convnets/densenet.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::
8383
inplanes = floor(Int, outplanes * reduction)
8484
end
8585
push!(layers, BatchNorm(outplanes, relu))
86-
return Chain(Chain(layers),
87-
Chain(AdaptiveMeanPool((1, 1)),
88-
MLUtils.flatten,
89-
Dense(outplanes, nclasses)))
86+
return Chain(Chain(layers...), create_classifier(outplanes, nclasses))
9087
end
9188

9289
"""

src/convnets/efficientnet.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ function efficientnet(scalings, block_configs; max_width::Integer = 1280,
2828
scalew(w) = wscale 1 ? w : ceil(Int64, wscale * w)
2929
scaled(d) = dscale 1 ? d : ceil(Int64, dscale * d)
3030
out_channels = _round_channels(scalew(32), 8)
31-
stem = conv_norm((3, 3), inchannels, out_channels, swish;
32-
bias = false, stride = 2, pad = SamePad())
31+
stem = conv_norm((3, 3), inchannels, out_channels, swish; bias = false, stride = 2,
32+
pad = SamePad())
3333
blocks = []
3434
for (n, k, s, e, i, o) in block_configs
3535
in_channels = _round_channels(scalew(i), 8)
@@ -44,13 +44,11 @@ function efficientnet(scalings, block_configs; max_width::Integer = 1280,
4444
stride = 1, reduction = 4))
4545
end
4646
end
47-
blocks = Chain(blocks...)
4847
head_out_channels = _round_channels(max_width, 8)
49-
head = conv_norm((1, 1), out_channels, head_out_channels, swish;
50-
bias = false, pad = SamePad())
51-
top = Dense(head_out_channels, nclasses)
52-
return Chain(Chain([stem..., blocks, head...]),
53-
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top))
48+
append!(blocks,
49+
conv_norm((1, 1), out_channels, head_out_channels, swish;
50+
bias = false, pad = SamePad()))
51+
return Chain(Chain(stem..., blocks...), create_classifier(head_out_channels, nclasses))
5452
end
5553

5654
# n: # of block repetitions

src/convnets/inception/googlenet.jl

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,29 @@ Create an Inception-v1 model (commonly referred to as GoogLeNet)
3636
3737
- `nclasses`: the number of output classes
3838
"""
39-
function googlenet(; nclasses::Integer = 1000)
40-
layers = Chain(Chain(Conv((7, 7), 3 => 64; stride = 2, pad = 3),
41-
MaxPool((3, 3); stride = 2, pad = 1),
42-
Conv((1, 1), 64 => 64),
43-
Conv((3, 3), 64 => 192; pad = 1),
44-
MaxPool((3, 3); stride = 2, pad = 1),
45-
_inceptionblock(192, 64, 96, 128, 16, 32, 32),
46-
_inceptionblock(256, 128, 128, 192, 32, 96, 64),
47-
MaxPool((3, 3); stride = 2, pad = 1),
48-
_inceptionblock(480, 192, 96, 208, 16, 48, 64),
49-
_inceptionblock(512, 160, 112, 224, 24, 64, 64),
50-
_inceptionblock(512, 128, 128, 256, 24, 64, 64),
51-
_inceptionblock(512, 112, 144, 288, 32, 64, 64),
52-
_inceptionblock(528, 256, 160, 320, 32, 128, 128),
53-
MaxPool((3, 3); stride = 2, pad = 1),
54-
_inceptionblock(832, 256, 160, 320, 32, 128, 128),
55-
_inceptionblock(832, 384, 192, 384, 48, 128, 128)),
56-
Chain(AdaptiveMeanPool((1, 1)),
57-
MLUtils.flatten,
58-
Dropout(0.4),
59-
Dense(1024, nclasses)))
60-
return layers
39+
function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
40+
backbone = Chain(Conv((7, 7), inchannels => 64; stride = 2, pad = 3),
41+
MaxPool((3, 3); stride = 2, pad = 1),
42+
Conv((1, 1), 64 => 64),
43+
Conv((3, 3), 64 => 192; pad = 1),
44+
MaxPool((3, 3); stride = 2, pad = 1),
45+
_inceptionblock(192, 64, 96, 128, 16, 32, 32),
46+
_inceptionblock(256, 128, 128, 192, 32, 96, 64),
47+
MaxPool((3, 3); stride = 2, pad = 1),
48+
_inceptionblock(480, 192, 96, 208, 16, 48, 64),
49+
_inceptionblock(512, 160, 112, 224, 24, 64, 64),
50+
_inceptionblock(512, 128, 128, 256, 24, 64, 64),
51+
_inceptionblock(512, 112, 144, 288, 32, 64, 64),
52+
_inceptionblock(528, 256, 160, 320, 32, 128, 128),
53+
MaxPool((3, 3); stride = 2, pad = 1),
54+
_inceptionblock(832, 256, 160, 320, 32, 128, 128),
55+
_inceptionblock(832, 384, 192, 384, 48, 128, 128))
56+
classifier = create_classifier(1024, nclasses; dropout_rate = 0.4)
57+
return Chain(backbone, classifier)
6158
end
6259

6360
"""
64-
GoogLeNet(; pretrain::Bool = false, nclasses::Integer = 1000)
61+
GoogLeNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
6562
6663
Create an Inception-v1 model (commonly referred to as `GoogLeNet`)
6764
([reference](https://arxiv.org/abs/1409.4842v1)).
@@ -82,8 +79,9 @@ struct GoogLeNet
8279
end
8380
@functor GoogLeNet
8481

85-
function GoogLeNet(; pretrain::Bool = false, nclasses::Integer = 1000)
86-
layers = googlenet(; nclasses = nclasses)
82+
function GoogLeNet(; pretrain::Bool = false, inchannels::Integer = 3,
83+
nclasses::Integer = 1000)
84+
layers = googlenet(; inchannels, nclasses)
8785
if pretrain
8886
loadpretrain!(layers, "GoogLeNet")
8987
end

src/convnets/inception/inceptionresnetv2.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,24 +77,23 @@ Creates an InceptionResNetv2 model.
7777
"""
7878
function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0,
7979
nclasses::Integer = 1000)
80-
body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
81-
conv_norm((3, 3), 32, 32)...,
82-
conv_norm((3, 3), 32, 64; pad = 1)...,
83-
MaxPool((3, 3); stride = 2),
84-
conv_norm((3, 3), 64, 80)...,
85-
conv_norm((3, 3), 80, 192)...,
86-
MaxPool((3, 3); stride = 2),
87-
mixed_5b(),
88-
[block35(0.17f0) for _ in 1:10]...,
89-
mixed_6a(),
90-
[block17(0.10f0) for _ in 1:20]...,
91-
mixed_7a(),
92-
[block8(0.20f0) for _ in 1:9]...,
93-
block8(; activation = relu),
94-
conv_norm((1, 1), 2080, 1536)...)
95-
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate),
96-
Dense(1536, nclasses))
97-
return Chain(body, head)
80+
backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
81+
conv_norm((3, 3), 32, 32)...,
82+
conv_norm((3, 3), 32, 64; pad = 1)...,
83+
MaxPool((3, 3); stride = 2),
84+
conv_norm((3, 3), 64, 80)...,
85+
conv_norm((3, 3), 80, 192)...,
86+
MaxPool((3, 3); stride = 2),
87+
mixed_5b(),
88+
[block35(0.17f0) for _ in 1:10]...,
89+
mixed_6a(),
90+
[block17(0.10f0) for _ in 1:20]...,
91+
mixed_7a(),
92+
[block8(0.20f0) for _ in 1:9]...,
93+
block8(; activation = relu),
94+
conv_norm((1, 1), 2080, 1536)...)
95+
classifier = create_classifier(1536, nclasses; dropout_rate)
96+
return Chain(backbone, classifier)
9897
end
9998

10099
"""

src/convnets/inception/inceptionv3.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -136,29 +136,26 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
136136
- `nclasses`: the number of output classes
137137
"""
138138
function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
139-
layer = Chain(Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
140-
conv_norm((3, 3), 32, 32)...,
141-
conv_norm((3, 3), 32, 64; pad = 1)...,
142-
MaxPool((3, 3); stride = 2),
143-
conv_norm((1, 1), 64, 80)...,
144-
conv_norm((3, 3), 80, 192)...,
145-
MaxPool((3, 3); stride = 2),
146-
inceptionv3_a(192, 32),
147-
inceptionv3_a(256, 64),
148-
inceptionv3_a(288, 64),
149-
inceptionv3_b(288),
150-
inceptionv3_c(768, 128),
151-
inceptionv3_c(768, 160),
152-
inceptionv3_c(768, 160),
153-
inceptionv3_c(768, 192),
154-
inceptionv3_d(768),
155-
inceptionv3_e(1280),
156-
inceptionv3_e(2048)),
157-
Chain(AdaptiveMeanPool((1, 1)),
158-
Dropout(0.2),
159-
MLUtils.flatten,
160-
Dense(2048, nclasses)))
161-
return layer
139+
backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
140+
conv_norm((3, 3), 32, 32)...,
141+
conv_norm((3, 3), 32, 64; pad = 1)...,
142+
MaxPool((3, 3); stride = 2),
143+
conv_norm((1, 1), 64, 80)...,
144+
conv_norm((3, 3), 80, 192)...,
145+
MaxPool((3, 3); stride = 2),
146+
inceptionv3_a(192, 32),
147+
inceptionv3_a(256, 64),
148+
inceptionv3_a(288, 64),
149+
inceptionv3_b(288),
150+
inceptionv3_c(768, 128),
151+
inceptionv3_c(768, 160),
152+
inceptionv3_c(768, 160),
153+
inceptionv3_c(768, 192),
154+
inceptionv3_d(768),
155+
inceptionv3_e(1280),
156+
inceptionv3_e(2048))
157+
classifier = create_classifier(2048, nclasses; dropout_rate = 0.2)
158+
return Chain(backbone, classifier)
162159
end
163160

164161
"""

src/convnets/inception/inceptionv4.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -95,31 +95,30 @@ Create an Inceptionv4 model.
9595
"""
9696
function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3,
9797
nclasses::Integer = 1000)
98-
body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
99-
conv_norm((3, 3), 32, 32)...,
100-
conv_norm((3, 3), 32, 64; pad = 1)...,
101-
mixed_3a(),
102-
mixed_4a(),
103-
mixed_5a(),
104-
inceptionv4_a(),
105-
inceptionv4_a(),
106-
inceptionv4_a(),
107-
inceptionv4_a(),
108-
reduction_a(), # mixed_6a
109-
inceptionv4_b(),
110-
inceptionv4_b(),
111-
inceptionv4_b(),
112-
inceptionv4_b(),
113-
inceptionv4_b(),
114-
inceptionv4_b(),
115-
inceptionv4_b(),
116-
reduction_b(), # mixed_7a
117-
inceptionv4_c(),
118-
inceptionv4_c(),
119-
inceptionv4_c())
120-
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate),
121-
Dense(1536, nclasses))
122-
return Chain(body, head)
98+
backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
99+
conv_norm((3, 3), 32, 32)...,
100+
conv_norm((3, 3), 32, 64; pad = 1)...,
101+
mixed_3a(),
102+
mixed_4a(),
103+
mixed_5a(),
104+
inceptionv4_a(),
105+
inceptionv4_a(),
106+
inceptionv4_a(),
107+
inceptionv4_a(),
108+
reduction_a(), # mixed_6a
109+
inceptionv4_b(),
110+
inceptionv4_b(),
111+
inceptionv4_b(),
112+
inceptionv4_b(),
113+
inceptionv4_b(),
114+
inceptionv4_b(),
115+
inceptionv4_b(),
116+
reduction_b(), # mixed_7a
117+
inceptionv4_c(),
118+
inceptionv4_c(),
119+
inceptionv4_c())
120+
classifier = create_classifier(1536, nclasses; dropout_rate)
121+
return Chain(backbone, classifier)
123122
end
124123

125124
"""

src/convnets/inception/xception.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,17 @@ Creates an Xception model.
5757
- `nclasses`: the number of output classes.
5858
"""
5959
function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
60-
body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)...,
61-
conv_norm((3, 3), 32, 64; bias = false)...,
62-
xception_block(64, 128, 2; stride = 2, start_with_relu = false),
63-
xception_block(128, 256, 2; stride = 2),
64-
xception_block(256, 728, 2; stride = 2),
65-
[xception_block(728, 728, 3) for _ in 1:8]...,
66-
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
67-
depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
68-
depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
69-
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate),
70-
Dense(2048, nclasses))
71-
return Chain(body, head)
60+
backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)...,
61+
conv_norm((3, 3), 32, 64; bias = false)...,
62+
xception_block(64, 128, 2; stride = 2, start_with_relu = false),
63+
xception_block(128, 256, 2; stride = 2),
64+
xception_block(256, 728, 2; stride = 2),
65+
[xception_block(728, 728, 3) for _ in 1:8]...,
66+
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
67+
depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
68+
depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
69+
classifier = create_classifier(2048, nclasses; dropout_rate)
70+
return Chain(backbone, classifier)
7271
end
7372

7473
"""

0 commit comments

Comments
 (0)