
Commit 88f59c2

Tweaks to prepare pretrained models
1 parent 7449985 commit 88f59c2

File tree

src/convnets/densenet.jl
src/convnets/inception/inceptionv3.jl
src/convnets/resnets/core.jl
src/convnets/resnets/resnet.jl
src/convnets/resnets/resnext.jl
src/convnets/squeezenet.jl
src/layers/mlp.jl
src/utilities.jl
test/convnets.jl
test/runtests.jl

10 files changed: +34 -27 lines changed

src/convnets/densenet.jl

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer
     _checkconfig(config, keys(DENSENET_CONFIGS))
     model = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
     if pretrain
-        loadpretrain!(model, string("DenseNet", config))
+        loadpretrain!(model, string("densenet", config))
     end
     return model
 end
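Note: the pretrained weight keys move from CamelCase to lowercase here (and in the ResNet, ResNeXt and SqueezeNet hunks below). A minimal usage sketch, assuming the "densenet121" artifact is published under that lowercase name:

using Metalhead

# Requests the weight set registered as "densenet121" via loadpretrain!.
model = DenseNet(121; pretrain = true)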

src/convnets/inception/inceptionv3.jl

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-## Inceptionv3
-
 """
     inceptionv3_a(inplanes, pool_proj)
 

src/convnets/resnets/core.jl

Lines changed: 3 additions & 5 deletions
@@ -280,15 +280,13 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
 end
 
 function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
-                connection,
-                classifier_fn)
+                connection, classifier_fn)
     # Build stages of the ResNet
     stage_blocks = resnet_stages(get_layers, block_repeats, connection)
     backbone = Chain(stem, stage_blocks)
-    # Build the classifier head
+    # Add classifier to the backbone
     nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
-    classifier = classifier_fn(nfeaturemaps)
-    return Chain(backbone, classifier)
+    return Chain(backbone, classifier_fn(nfeaturemaps))
 end
 
 function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
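Note: the refactor folds the classifier construction into the returned Chain. A standalone sketch of the same pattern using ordinary Flux layers (none of the names below are Metalhead's):

using Flux

# Toy backbone; Flux.outputsize infers the output channel count without running real data.
backbone = Chain(Conv((3, 3), 3 => 16; pad = 1), AdaptiveMeanPool((1, 1)))
img_dims = (32, 32, 3)
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]   # -> 16
# classifier_fn plays the role of the callback passed to resnet above.
classifier_fn = n -> Chain(Flux.flatten, Dense(n => 10))
model = Chain(backbone, classifier_fn(nfeaturemaps))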

src/convnets/resnets/resnet.jl

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ function ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
     _checkconfig(depth, keys(RESNET_CONFIGS))
     layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses)
     if pretrain
-        loadpretrain!(layers, string("ResNet", depth))
+        loadpretrain!(layers, string("resnet", depth))
     end
     return ResNet(layers)
 end

src/convnets/resnets/resnext.jl

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
     _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
     layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
     if pretrain
-        loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
+        loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d"))
     end
     return ResNeXt(layers)
 end
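Note: the ResNeXt key now also encodes cardinality and base width, with a trailing "d", matching the torchvision-style checkpoint names. For example:

# With depth = 50, cardinality = 32 and base_width = 4, the lookup key becomes:
string("resnext", 50, "_", 32, "x", 4, "d")   # -> "resnext50_32x4d"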

src/convnets/squeezenet.jl

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ function SqueezeNet(; pretrain::Bool = false, inchannels::Integer = 3,
                     nclasses::Integer = 1000)
     layers = squeezenet(; inchannels, nclasses)
     if pretrain
-        loadpretrain!(layers, "SqueezeNet")
+        loadpretrain!(layers, "squeezenet")
     end
     return SqueezeNet(layers)
 end

src/layers/mlp.jl

Lines changed: 9 additions & 6 deletions
@@ -66,16 +66,19 @@ Creates a classifier head to be used for models.
 function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
                            use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)),
                            dropout_rate = nothing)
-    # Pooling
+    # Decide whether to flatten the input or not
     flatten_in_pool = !use_conv && pool_layer !== identity
     if use_conv
         @assert pool_layer === identity
        "`pool_layer` must be identity if `use_conv` is true"
     end
-    global_pool = flatten_in_pool ? [pool_layer, MLUtils.flatten] : [pool_layer]
+    classifier = []
+    flatten_in_pool ? push!(classifier, pool_layer, MLUtils.flatten) :
+    push!(classifier, pool_layer)
+    # Dropout is applied after the pooling layer
+    isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate))
     # Fully-connected layer
-    fc = use_conv ? Conv((1, 1), inplanes => nclasses, activation) :
-         Dense(inplanes => nclasses, activation)
-    drop = isnothing(dropout_rate) ? [] : [Dropout(dropout_rate)]
-    return Chain(global_pool..., drop..., fc)
+    use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) :
+    push!(classifier, Dense(inplanes => nclasses, activation))
+    return Chain(classifier...)
 end
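Note: create_classifier now accumulates the pooling, optional flatten, optional dropout, and the final layer in a single vector before splatting into a Chain. A self-contained sketch of that accumulation pattern (the function name and defaults below are illustrative, not the exported API):

using Flux

function classifier_sketch(inplanes, nclasses; dropout_rate = nothing,
                           pool_layer = AdaptiveMeanPool((1, 1)))
    layers = []
    push!(layers, pool_layer, Flux.flatten)                          # pool, then flatten to a vector
    isnothing(dropout_rate) || push!(layers, Dropout(dropout_rate))  # dropout after pooling
    push!(layers, Dense(inplanes => nclasses))                       # fully-connected head
    return Chain(layers...)
end

head = classifier_sketch(512, 1000; dropout_rate = 0.2)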

src/utilities.jl

Lines changed: 1 addition & 2 deletions
@@ -73,6 +73,5 @@ end
 
 # Utility function for depth and configuration checks in models
 function _checkconfig(config, configs)
-    @assert config in configs
-    return "Invalid configuration. Must be one of $(sort(collect(configs)))."
+    @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
 end
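Note: the message only becomes part of the AssertionError when it sits on the same @assert expression as the condition; the old two-line form asserted without a message and never surfaced the string. A quick sketch of the corrected behaviour:

configs = (18, 34, 50, 101, 152)
config = 20
@assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
# ERROR: AssertionError: Invalid configuration. Must be one of [18, 34, 50, 101, 152].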

test/convnets.jl

Lines changed: 7 additions & 8 deletions
@@ -22,15 +22,14 @@ end
 
 @testset "ResNet" begin
     # Tests for pretrained ResNets
-    ## TODO: find a way to port pretrained models to the new ResNet API
     @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
         m = ResNet(sz)
         @test size(m(x_224)) == (1000, 1)
-        # if (ResNet, sz) in PRETRAINED_MODELS
-        #     @test acctest(ResNet(sz, pretrain = true))
-        # else
-        #     @test_throws ArgumentError ResNet(sz, pretrain = true)
-        # end
+        if (ResNet, sz) in PRETRAINED_MODELS
+            @test acctest(ResNet(sz, pretrain = true))
+        else
+            @test_throws ArgumentError ResNet(sz, pretrain = true)
+        end
     end
 
 @testset "resnet" begin
@@ -79,9 +78,9 @@ end
         m = ResNeXt(depth; cardinality, base_width)
         @test size(m(x_224)) == (1000, 1)
         if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS
-            @test acctest(ResNeXt(depth, pretrain = true))
+            @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true))
         else
-            @test_throws ArgumentError ResNeXt(depth, pretrain = true)
+            @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true)
         end
         @test gradtest(m, x_224)
         _gc()
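Note: passing cardinality and base_width through matters for the pretrained path because the artifact key is derived from them (see src/convnets/resnets/resnext.jl above). A sketch, assuming the constructor defaults shown in that hunk header (cardinality = 32; base_width assumed to default to 4):

using Metalhead

m_default = ResNeXt(101)                                    # would look up "resnext101_32x4d" when pretrain = true
m_exact   = ResNeXt(101; cardinality = 64, base_width = 4)  # would look up "resnext101_64x4d"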

test/runtests.jl

Lines changed: 10 additions & 0 deletions
@@ -8,11 +8,21 @@ const PRETRAINED_MODELS = [
     (VGG, 13, false),
     (VGG, 16, false),
     (VGG, 19, false),
+    SqueezeNet,
     (ResNet, 18),
     (ResNet, 34),
     (ResNet, 50),
     (ResNet, 101),
     (ResNet, 152),
+    (WideResNet, 50),
+    (WideResNet, 101),
+    (ResNeXt, 50, 32, 4),
+    (ResNeXt, 101, 64, 4),
+    (ResNeXt, 101, 32, 8),
+    (DenseNet, 121),
+    (DenseNet, 161),
+    (DenseNet, 169),
+    (DenseNet, 201),
 ]
 
 function _gc()
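Note: entries in PRETRAINED_MODELS are either a bare model type or a tuple of type plus configuration, and the test suites check membership against the exact entry. A self-contained sketch of that convention (the list name below is illustrative):

using Metalhead

const PRETRAINED = [SqueezeNet, (ResNet, 50), (ResNeXt, 101, 32, 8), (DenseNet, 121)]

(ResNet, 50) in PRETRAINED            # true  -> run the accuracy test with pretrain = true
(ResNeXt, 101, 32, 4) in PRETRAINED   # false -> expect an ArgumentError for pretrain = true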
