
Commit 55565d8

Xception
1 parent 40a8351 commit 55565d8

File tree

5 files changed: +138 additions, −25 deletions

- src/Metalhead.jl
- src/convnets/inception.jl
- src/convnets/mobilenet.jl
- src/layers/conv.jl
- src/utilities.jl

src/Metalhead.jl

Lines changed: 7 additions & 10 deletions

@@ -37,22 +37,19 @@ include("vit-based/vit.jl")
 
 include("pretrain.jl")
 
-export AlexNet,
-       VGG, VGG11, VGG13, VGG16, VGG19,
-       GoogLeNet,
-       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-       Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2,
-       SqueezeNet,
-       ResNeXt,
+export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
+       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
-       MobileNetv1, MobileNetv2, MobileNetv3,
+       GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
+       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3,
        MLPMixer, ResMLP, gMLP,
        ViT,
        ConvMixer, ConvNeXt
 
 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :GoogLeNet, :ResNet, :ResNeXt, :Inceptionv3,
-          :SqueezeNet, :DenseNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
+for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet,
+          :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
+          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
          :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
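With this change the new wrapper type joins the exported models and the pretty-printing hook. A minimal usage sketch, assuming the package is loaded at this commit:

using Metalhead

# Xception is now exported alongside the other Inception-family models.
model = Xception()                  # defaults to 1000 ImageNet classes
small = Xception(; nclasses = 10)   # custom classifier head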

src/convnets/inception.jl

Lines changed: 100 additions & 0 deletions

@@ -478,3 +478,103 @@ end
 
 backbone(m::InceptionResNetv2) = m.layers[1]
 classifier(m::InceptionResNetv2) = m.layers[2]
+
+## Xception
+
+"""
+    xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true,
+                   grow_first = true)
+
+Create an Xception block.
+([reference](https://arxiv.org/abs/1610.02357))
+
+# Arguments
+
+- inchannels: number of input channels.
+- outchannels: number of output channels.
+- nrepeats: number of repeats of depthwise separable convolution layers.
+- stride: stride by which to downsample the input.
+- start_with_relu: if true, start the block with a ReLU activation.
+- grow_first: if true, increase the number of channels at the first convolution.
+"""
+function xception_block(inchannels, outchannels, nrepeats; stride = 1,
+                        start_with_relu = true,
+                        grow_first = true)
+    if outchannels != inchannels || stride != 1
+        skip = conv_bn((1, 1), inchannels, outchannels, identity; stride = stride,
+                       bias = false)
+    else
+        skip = [identity]
+    end
+    layers = []
+    for i in 1:nrepeats
+        if grow_first
+            inc = i == 1 ? inchannels : outchannels
+            outc = outchannels
+        else
+            inc = inchannels
+            outc = i == nrepeats ? outchannels : inchannels
+        end
+        push!(layers, Base.Fix1(applyactivation, relu))
+        append!(layers,
+                depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false,
+                                      use_bn1 = false, use_bn2 = false))
+        push!(layers, BatchNorm(outc))
+    end
+    layers = start_with_relu ? layers : layers[2:end]
+    push!(layers, MaxPool((3, 3); stride = stride, pad = 1))
+    return Parallel(+, Chain(skip...), Chain(layers...))
+end
+
+"""
+    xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+
+Creates an Xception model.
+([reference](https://arxiv.org/abs/1610.02357))
+
+# Arguments
+
+- inchannels: number of input channels.
+- dropout: rate of dropout in classifier head.
+- nclasses: the number of output classes.
+"""
+function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)...,
+                 conv_bn((3, 3), 32, 64; bias = false)...,
+                 xception_block(64, 128, 2; stride = 2, start_with_relu = false),
+                 xception_block(128, 256, 2; stride = 2),
+                 xception_block(256, 728, 2; stride = 2),
+                 [xception_block(728, 728, 3) for _ in 1:8]...,
+                 xception_block(728, 1024, 2; stride = 2, grow_first = false),
+                 depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
+                 depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
+    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses))
+    return Chain(body, head)
+end
+
+struct Xception
+    layers::Any
+end
+
+"""
+    Xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+
+Creates an Xception model.
+([reference](https://arxiv.org/abs/1610.02357))
+
+# Arguments
+- inchannels: number of input channels.
+- dropout: rate of dropout in classifier head.
+- nclasses: the number of output classes.
+"""
+function Xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    layers = xception(; inchannels, dropout, nclasses)
+    return Xception(layers)
+end
+
+@functor Xception
+
+(m::Xception)(x) = m.layers(x)
+
+backbone(m::Xception) = m.layers[1]
+classifier(m::Xception) = m.layers[2]
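As a sanity check on the new architecture, a hedged forward-pass sketch (the 299×299 resolution follows the Xception paper; shapes use Flux's WHCN layout):

using Metalhead, Flux

model = Xception(; nclasses = 10)
x = rand(Float32, 299, 299, 3, 1)   # width × height × channels × batch
y = model(x)
size(y)                             # (10, 1): one logit column per image

# The accessors above split the wrapper into its two stages:
feats_fn = backbone(model)          # stem + Xception blocks
head_fn = classifier(model)         # global pool, flatten, dropout, dense

Design note: each `xception_block` returns a `Parallel(+, skip, main)` residual pair, and the strided 1×1 `conv_bn` skip path is only built when the block changes channel count or spatial resolution; otherwise the skip is `identity`.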

src/convnets/mobilenet.jl

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ function mobilenetv1(width_mult, config;
        for _ in 1:nrepeats
            layer = dw ?
                    depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
-                                          stride = stride, pad = 1) :
+                                          stride = stride, pad = 1, bias = false) :
                    conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
            append!(layers, layer)
            inchannels = outch
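The new `bias = false` is a parameter-count fix rather than a behavioral change: a convolution feeding straight into a batch norm has its bias cancelled by the normalization, with the per-channel offset re-learned as BatchNorm's shift β. A small illustration of the pattern in plain Flux (hypothetical layer sizes):

using Flux

# The conv bias is redundant before BatchNorm, so it is dropped;
# the β parameter of BatchNorm(8, relu) supplies the per-channel offset.
block = Chain(Conv((3, 3), 3 => 8; pad = 1, bias = false), BatchNorm(8, relu))
x = rand(Float32, 32, 32, 3, 1)
size(block(x))  # (32, 32, 8, 1)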

src/layers/conv.jl

Lines changed: 22 additions & 14 deletions

@@ -1,8 +1,8 @@
 """
     conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-            rev = false, preact = true,
-            stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
-            initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)
+            rev = false, preact = false, use_bn = true,
+            initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1,
+            kwargs...)
 
 Create a convolution + batch normalization pair with activation.
 
@@ -15,6 +15,8 @@ Create a convolution + batch normalization pair with activation.
 - `rev`: set to `true` to place the batch norm before the convolution
 - `preact`: set to `true` to place the activation function before the batch norm
   (only compatible with `rev = false`)
+- `use_bn`: set to `false` to disable batch normalization
+  (only compatible with `rev = false` and `preact = false`)
 - `stride`: stride of the convolution kernel
 - `pad`: padding of the convolution kernel
 - `dilation`: dilation of the convolution kernel
@@ -24,9 +26,13 @@ Create a convolution + batch normalization pair with activation.
 - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
 """
 function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-                 rev = false, preact = false,
+                 rev = false, preact = false, use_bn = true,
                  initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1,
                  kwargs...)
+    if !use_bn
+        (preact || rev) ? throw("preact only supported with `use_bn = true`") :
+        return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)]
+    end
     layers = []
     if rev
         activations = (conv = activation, bn = identity)
@@ -49,18 +55,18 @@
 
 """
     depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-                          rev = false,
-                          stride = 1, pad = 0, dilation = 1, [bias, weight, init],
-                          initβ = Flux.zeros32, initγ = Flux.ones32,
-                          ϵ = 1f-5, momentum = 1f-1)
+                          rev = false, use_bn1 = true, use_bn2 = true,
+                          initβ = Flux.zeros32, initγ = Flux.ones32,
+                          ϵ = 1.0f-5, momentum = 1.0f-1,
+                          stride = 1, kwargs...)
 
-Create a depthwise separable convolution chain as used in MobileNet v1.
+Create a depthwise separable convolution chain as used in MobileNetv1.
 This is sequence of layers:
 
 - a `kernelsize` depthwise convolution from `inplanes => inplanes`
-- a batch norm layer + `activation`
+- a batch norm layer + `activation` (if `use_bn1`; otherwise `activation` is applied to the convolution output)
 - a `kernelsize` convolution from `inplanes => outplanes`
-- a batch norm layer + `activation`
+- a batch norm layer + `activation` (if `use_bn2`; otherwise `activation` is applied to the convolution output)
 
 See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 
@@ -71,6 +77,8 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 - `outplanes`: number of output feature maps
 - `activation`: the activation function for the final layer
 - `rev`: set to `true` to place the batch norm before the convolution
+- `use_bn1`: set to `true` to use a batch norm after the depthwise convolution
+- `use_bn2`: set to `true` to use a batch norm after the pointwise convolution
 - `stride`: stride of the first convolution kernel
 - `pad`: padding of the first convolution kernel
 - `dilation`: dilation of the first convolution kernel
@@ -79,16 +87,16 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
 """
 function depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-                               rev = false,
+                               rev = false, use_bn1 = true, use_bn2 = true,
                                initβ = Flux.zeros32, initγ = Flux.ones32,
                                ϵ = 1.0f-5, momentum = 1.0f-1,
                                stride = 1, kwargs...)
     return vcat(conv_bn(kernelsize, inplanes, inplanes, activation;
                         rev = rev, initβ = initβ, initγ = initγ,
-                        ϵ = ϵ, momentum = momentum,
+                        ϵ = ϵ, momentum = momentum, use_bn = use_bn1,
                         stride = stride, groups = Int(inplanes), kwargs...),
                 conv_bn((1, 1), inplanes, outplanes, activation;
-                        rev = rev, initβ = initβ, initγ = initγ,
+                        rev = rev, initβ = initβ, initγ = initγ, use_bn = use_bn2,
                         ϵ = ϵ, momentum = momentum))
 end
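A sketch of the new switches, assuming the helpers are brought into scope from the package (they are internal, so a qualified import is used): with `use_bn = false`, `conv_bn` collapses to a single `Conv` with the activation fused in, which is what `xception_block` relies on to interleave its own pre-activations and batch norms.

using Metalhead: conv_bn, depthwise_sep_conv_bn
using Flux

normed = conv_bn((3, 3), 16, 32, relu; pad = 1)                  # Conv followed by BatchNorm
plain  = conv_bn((3, 3), 16, 32, relu; pad = 1, use_bn = false)  # a single fused Conv

# In the depthwise separable helper each half toggles independently:
sep = depthwise_sep_conv_bn((3, 3), 16, 32; pad = 1, bias = false,
                            use_bn1 = false, use_bn2 = false)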

src/utilities.jl

Lines changed: 8 additions & 0 deletions

@@ -55,6 +55,14 @@ Equivalent to `permutedims(x, perm)`.
 """
 swapdims(perm) = Base.Fix2(permutedims, perm)
 
+"""
+    applyactivation(activation, x)
+
+Apply an activation function to a given input.
+Equivalent to `activation.(x)`.
+"""
+applyactivation(activation, x) = activation.(x)
+
 # Utility function for pretty printing large models
 function _maybe_big_show(io, model)
     if isdefined(Flux, :_big_show)
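`applyactivation` lets a broadcast activation stand alone as a layer; partially applying it with `Base.Fix1` is how `xception_block` above inserts its pre-activation ReLUs into a `Chain`. A minimal sketch of the same pattern:

using Flux

applyactivation(activation, x) = activation.(x)

# Base.Fix1 pins the first argument, yielding a callable equivalent to x -> relu.(x)
prerelu = Base.Fix1(applyactivation, relu)
prerelu(Float32[-1.0, 2.0])  # Float32[0.0, 2.0]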
