
Commit 8e6f929

More minor tweaks:

1. Support `pretrain` for the Inception model APIs.
2. Group the deprecations in a single source file so they are easier to find.
3. Random formatting nitpicks.
4. Use a plain broadcast instead of `applyactivation`.
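For orientation, a minimal usage sketch of the keyword from item 1 (illustrative only, not part of the diff). It assumes Metalhead is loaded; per the new docstring warnings, `pretrain = true` is expected to fail in `loadpretrain!` until pretrained weights are actually published for these models.

using Metalhead

# The Inception-family constructors now accept a `pretrain` keyword.
m1 = Inceptionv4(; pretrain = false, nclasses = 1000)        # randomly initialised weights
m2 = InceptionResNetv2(; pretrain = false, dropout = 0.0)    # same keyword surface

# `pretrain = true` routes through `loadpretrain!`, which the docstrings warn
# is not yet backed by published weights for these models:
# Inceptionv4(; pretrain = true)   # expected to throw until weights exist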
1 parent: 55565d8

File tree

7 files changed (+55, -46 lines)

src/convnets/densenet.jl

Lines changed: 0 additions & 10 deletions

@@ -72,7 +72,6 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
     layers = []
     append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
     push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1)))
-
     outplanes = 0
     for (i, rates) in enumerate(growth_rates)
         outplanes = inplanes + sum(rates)
@@ -82,7 +81,6 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
         inplanes = floor(Int, outplanes * reduction)
     end
     push!(layers, BatchNorm(outplanes, relu))
-
     return Chain(Chain(layers),
                  Chain(AdaptiveMeanPool((1, 1)),
                        MLUtils.flatten,
@@ -131,7 +129,6 @@ function DenseNet(nblocks::NTuple{N, <:Integer};
     layers = densenet(nblocks; growth_rate = growth_rate,
                       reduction = reduction,
                       nclasses = nclasses)
-
     return DenseNet(layers)
 end

@@ -164,13 +161,6 @@ See also [`Metalhead.densenet`](#).
 function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
     @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))."
     model = DenseNet(densenet_config[config]; nclasses = nclasses)
-
     pretrain && loadpretrain!(model, string("DenseNet", config))
     return model
 end
-
-# deprecations
-@deprecate DenseNet121(; kw...) DenseNet(121; kw...)
-@deprecate DenseNet161(; kw...) DenseNet(161; kw...)
-@deprecate DenseNet169(; kw...) DenseNet(169; kw...)
-@deprecate DenseNet201(; kw...) DenseNet(201; kw...)
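A quick usage sketch of the config-keyed constructor that this diff keeps (illustrative only); the valid configs 121/161/169/201 are inferred from the deprecation aliases the commit moves into src/deprecations.jl.

using Metalhead

m = DenseNet(169; nclasses = 1000)   # replaces the removed DenseNet169() alias
# DenseNet(42)                       # would trip the @assert on densenet_config keys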

src/convnets/inception.jl

Lines changed: 26 additions & 14 deletions

@@ -134,10 +134,6 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
 # Arguments

 - `nclasses`: the number of output classes
-
-!!! warning
-
-    `inceptionv3` does not currently support pretrained weights.
 """
 function inceptionv3(; nclasses = 1000)
     layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)...,
@@ -197,9 +193,6 @@ end
 backbone(m::Inceptionv3) = m.layers[1]
 classifier(m::Inceptionv3) = m.layers[2]

-@deprecate Inception3 Inceptionv3
-@deprecate inception3 inceptionv3
-
 ## Inceptionv4

 function mixed_3a()
@@ -325,23 +318,29 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
 end

 """
-    Inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)

 Creates an Inceptionv4 model.
 ([reference](https://arxiv.org/abs/1602.07261))

 # Arguments

+- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
 - inchannels: number of input channels.
 - dropout: rate of dropout in classifier head.
 - nclasses: the number of output classes.
+
+!!! warning
+
+    `Inceptionv4` does not currently support pretrained weights.
 """
 struct Inceptionv4
     layers::Any
 end

-function Inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
     layers = inceptionv4(; inchannels, dropout, nclasses)
+    pretrain && loadpretrain!(layers, "Inceptionv4")
     return Inceptionv4(layers)
 end

@@ -452,23 +451,30 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
 end

 """
-    InceptionResNetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)

 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))

 # Arguments

+- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
 - inchannels: number of input channels.
 - dropout: rate of dropout in classifier head.
 - nclasses: the number of output classes.
+
+!!! warning
+
+    `InceptionResNetv2` does not currently support pretrained weights.
 """
 struct InceptionResNetv2
     layers::Any
 end

-function InceptionResNetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0,
+                           nclasses = 1000)
     layers = inceptionresnetv2(; inchannels, dropout, nclasses)
+    pretrain && loadpretrain!(layers, "InceptionResNetv2")
     return InceptionResNetv2(layers)
 end

@@ -515,7 +521,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
         inc = inchannels
         outc = i == nrepeats ? outchannels : inchannels
     end
-    push!(layers, Base.Fix1(applyactivation, relu))
+    push!(layers, x -> relu.(x))
     append!(layers,
             depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false,
                                   use_bn1 = false, use_bn2 = false))
@@ -557,15 +563,21 @@ struct Xception
 end

 """
-    Xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)

 Creates an Xception model.
 ([reference](https://arxiv.org/abs/1610.02357))

 # Arguments
+
+- pretrain: set to `true` to load the pre-trained weights for ImageNet.
 - inchannels: number of input channels.
 - dropout: rate of dropout in classifier head.
-- nclasses: the number of output classes.
+- nclasses: the number of output classes.
+
+!!! warning
+
+    `Xception` does not currently support pretrained weights.
 """
 function Xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
     layers = xception(; inchannels, dropout, nclasses)

src/convnets/resnet.jl

Lines changed: 0 additions & 8 deletions

@@ -256,16 +256,8 @@ resnet50_v1 = ResNet([1, 1, 4], [3, 4, 6, 3], :B; block = Metalhead.bottleneck_v
 """
 function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000)
     @assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))"
-
     config, block = resnet_config[depth]
     model = ResNet(config...; block = block, nclasses = nclasses)
     pretrain && loadpretrain!(model, string("ResNet", depth))
     return model
 end
-
-# Compat with Metalhead 0.6; remove in 0.7
-@deprecate ResNet18(; kw...) ResNet(18; kw...)
-@deprecate ResNet34(; kw...) ResNet(34; kw...)
-@deprecate ResNet50(; kw...) ResNet(50; kw...)
-@deprecate ResNet101(; kw...) ResNet(101; kw...)
-@deprecate ResNet152(; kw...) ResNet(152; kw...)

src/convnets/vgg.jl

Lines changed: 0 additions & 6 deletions

@@ -177,9 +177,3 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses
     end
     return model
 end
-
-# deprecations
-@deprecate VGG11(; kw...) VGG(11; kw...)
-@deprecate VGG13(; kw...) VGG(13; kw...)
-@deprecate VGG16(; kw...) VGG(16; kw...)
-@deprecate VGG19(; kw...) VGG(19; kw...)

src/deprecations.jl

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+## Deprecated in 0.7.0
+
+@deprecate DenseNet121(; kw...) DenseNet(121; kw...)
+@deprecate DenseNet161(; kw...) DenseNet(161; kw...)
+@deprecate DenseNet169(; kw...) DenseNet(169; kw...)
+@deprecate DenseNet201(; kw...) DenseNet(201; kw...)
+
+@deprecate VGG11(; kw...) VGG(11; kw...)
+@deprecate VGG13(; kw...) VGG(13; kw...)
+@deprecate VGG16(; kw...) VGG(16; kw...)
+@deprecate VGG19(; kw...) VGG(19; kw...)
+
+@deprecate ResNet18(; kw...) ResNet(18; kw...)
+@deprecate ResNet34(; kw...) ResNet(34; kw...)
+@deprecate ResNet50(; kw...) ResNet(50; kw...)
+@deprecate ResNet101(; kw...) ResNet(101; kw...)
+@deprecate ResNet152(; kw...) ResNet(152; kw...)
+
+# Deprecated in 0.7.3
+
+@deprecate Inception3 Inceptionv3
+@deprecate inception3 inceptionv3
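To illustrate what the grouped rules do at call sites (a sketch, not part of the diff): `@deprecate old new` keeps the old name callable, forwards it to the replacement, and emits a deprecation warning when Julia is started with `--depwarn=yes`.

using Metalhead

m_old = ResNet50()    # still works; warns under --depwarn=yes and forwards to ResNet(50)
m_new = ResNet(50)    # the replacement spelling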

src/utilities.jl

Lines changed: 0 additions & 8 deletions

@@ -55,14 +55,6 @@ Equivalent to `permutedims(x, perm)`.
 """
 swapdims(perm) = Base.Fix2(permutedims, perm)

-"""
-    applyactivation(activation, x)
-
-Apply an activation function to a given input.
-Equivalent to `activation.(x)`.
-"""
-applyactivation(activation, x) = activation.(x)
-
 # Utility function for pretty printing large models
 function _maybe_big_show(io, model)
     if isdefined(Flux, :_big_show)
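Since the removed docstring states `applyactivation(activation, x)` is equivalent to `activation.(x)`, the replacement in xception_block is behaviour-preserving. A small illustrative sketch of that equivalence, using Flux's `relu`:

using Flux: relu

x = randn(Float32, 4, 4)
y1 = relu.(x)              # plain broadcast, which is all the removed helper did
f = x -> relu.(x)          # the anonymous-function form now pushed into the layer list
y2 = f(x)
y1 == y2                   # true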

test/convnets.jl

Lines changed: 7 additions & 0 deletions

@@ -108,6 +108,13 @@ GC.gc()
         @test size(m(x_299)) == (1000, 2)
         @test gradtest(m, x_299)
     end
+    GC.safepoint()
+    GC.gc()
+    @testset "Xception" begin
+        m = Xception()
+        @test size(m(x_299)) == (1000, 2)
+        @test gradtest(m, x_299)
+    end
 end

 GC.safepoint()
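A standalone approximation of the new testset (assumptions: `x_299` is a 299×299 RGB batch of two images, matching the expected `(1000, 2)` output; the repo-local `gradtest` helper is omitted here since its definition is not part of this diff):

using Metalhead, Test

x_299 = rand(Float32, 299, 299, 3, 2)
m = Xception()
@test size(m(x_299)) == (1000, 2)   # 1000 ImageNet classes × batch of 2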
