From 49faf0986414e99bf98791fa77ffe7a112e5da04 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 4 Aug 2022 18:16:47 +0530 Subject: [PATCH 01/34] Initial commit for EfficientNetv2 --- src/Metalhead.jl | 13 ++-- .../{ => efficientnet}/efficientnet.jl | 25 +++--- src/convnets/efficientnet/efficientnetv2.jl | 78 +++++++++++++++++++ src/convnets/mobilenet/mobilenetv2.jl | 3 +- src/convnets/mobilenet/mobilenetv3.jl | 11 ++- src/convnets/resnets/resnext.jl | 3 +- src/layers/conv.jl | 2 +- src/layers/mlp.jl | 7 +- 8 files changed, 114 insertions(+), 28 deletions(-) rename src/convnets/{ => efficientnet}/efficientnet.jl (82%) create mode 100644 src/convnets/efficientnet/efficientnetv2.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index aa236454c..cad13afc9 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -33,6 +33,9 @@ include("convnets/inception/inceptionv3.jl") include("convnets/inception/inceptionv4.jl") include("convnets/inception/inceptionresnetv2.jl") include("convnets/inception/xception.jl") +## EfficientNets +include("convnets/efficientnet/efficientnet.jl") +include("convnets/efficientnet/efficientnetv2.jl") ## MobileNets include("convnets/mobilenet/mobilenetv1.jl") include("convnets/mobilenet/mobilenetv2.jl") @@ -40,7 +43,6 @@ include("convnets/mobilenet/mobilenetv3.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") -include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") @@ -61,13 +63,14 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, - SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, + SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, EfficientNetv2, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, - :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, - :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, +for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, + :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, + :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, + :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :EfficientNetv2, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl similarity index 82% rename from src/convnets/efficientnet.jl rename to src/convnets/efficientnet/efficientnet.jl index 91986fb92..795fbd592 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -29,28 +29,27 @@ function efficientnet(scalings::NTuple{2, Real}, wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - out_channels = _round_channels(scalew(32), 8) - stem = conv_norm((3, 3), inchannels, out_channels, swish; bias = false, stride = 2, + outplanes = _round_channels(scalew(32), 8) + stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, pad = SamePad()) blocks = [] for (n, k, s, e, i, o) in block_configs - in_channels = _round_channels(scalew(i), 8) - out_channels = _round_channels(scalew(o), 8) + inchannels = _round_channels(scalew(i), 8) + outplanes = _round_channels(scalew(o), 8) repeats = scaled(n) push!(blocks, - invertedresidual((k, k), in_channels, out_channels, swish; expansion = e, + invertedresidual((k, k), in_channels, outplanes, swish; expansion = e, stride = s, reduction = 4)) for _ in 1:(repeats - 1) push!(blocks, - invertedresidual((k, k), out_channels, out_channels, swish; expansion = e, + invertedresidual((k, k), outplanes, outplanes, swish; expansion = e, stride = 1, reduction = 4)) end end - head_out_channels = _round_channels(max_width, 8) + headplanes = _round_channels(max_width, 8) append!(blocks, - conv_norm((1, 1), out_channels, head_out_channels, swish; - bias = false, pad = SamePad())) - return Chain(Chain(stem..., blocks...), create_classifier(head_out_channels, nclasses)) + conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) + return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) end # n: # of block repetitions @@ -101,9 +100,11 @@ struct EfficientNet end @functor EfficientNet -function EfficientNet(config::Symbol; pretrain::Bool = false) +function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS) + model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS; + inchannels, nclasses) if pretrain loadpretrain!(model, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl new file mode 100644 index 000000000..58a7aea46 --- /dev/null +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -0,0 +1,78 @@ +function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, + width_mult::Real = 1.0, inchannels::Integer = 3, + nclasses::Integer = 1000) + # building first layer + inplanes = _round_channels(24 * width_mult, 8) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, + bias = false)) + # building inverted residual blocks + for (t, c, n, s, r) in config + outplanes = _round_channels(c * width_mult, 8) + for i in 1:n + push!(layers, + invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, + stride = i == 1 ? s : 1, + reduction = r == 1 ? 4 : nothing)) + inplanes = outplanes + end + end + # building last layers + outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, 8) : + max_width + append!(layers, conv_norm((1, 1), inplanes, outplanes, swish; bias = false)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) +end + +const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE + (1, 24, 2, 1, 0), + (4, 48, 4, 2, 0), + (4, 64, 4, 2, 0), + (4, 128, 6, 2, 1), + (6, 160, 9, 1, 1), + (6, 256, 15, 2, 1)], + :medium => [# t, c, n, s, SE + (1, 24, 3, 1, 0), + (4, 48, 5, 2, 0), + (4, 80, 5, 2, 0), + (4, 160, 7, 2, 1), + (6, 176, 14, 1, 1), + (6, 304, 18, 2, 1), + (6, 512, 5, 1, 1)], + :large => [# t, c, n, s, SE + (1, 32, 4, 1, 0), + (4, 64, 8, 2, 0), + (4, 96, 8, 2, 0), + (4, 192, 16, 2, 1), + (6, 256, 24, 1, 1), + (6, 512, 32, 2, 1), + (6, 640, 8, 1, 1)], + :xlarge => [# t, c, n, s, SE + (1, 32, 4, 1, 0), + (4, 64, 8, 2, 0), + (4, 96, 8, 2, 0), + (4, 192, 16, 2, 1), + (6, 256, 24, 1, 1), + (6, 512, 32, 2, 1), + (6, 640, 8, 1, 1)]) + +struct EfficientNetv2 + layers::Any +end +@functor EfficientNetv2 + +function EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) + layers = efficientnetv2(EFFNETV2_CONFIGS[config]; width_mult, inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnetv2")) + end + return EfficientNetv2(layers) +end + +(m::EfficientNetv2)(x) = m.layers(x) + +backbone(m::EfficientNetv2) = m.layers[1] +classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 84162e985..c73644073 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -29,7 +29,8 @@ function mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; # building first layer inplanes = _round_channels(32 * width_mult, divisor) layers = [] - append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) + append!(layers, + conv_norm((3, 3), inchannels, inplanes; bias = false, pad = 1, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, divisor) diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 7d06ab14d..607069bdd 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -24,14 +24,14 @@ Create a MobileNetv3 model. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes """ -function mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; +function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1024, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, + conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1, bias = false)) explanes = 0 # building inverted residual blocks @@ -45,9 +45,8 @@ function mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; inplanes = outplanes end # building last layers - output_channel = max_width - output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : - output_channel + output_channel = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : + max_width append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, output_channel, hardswish), @@ -119,7 +118,7 @@ function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = fals inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, [:small, :large]) max_width = (config == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[config]; max_width, inchannels, + layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv3", config)) diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 8c43d2f62..bb589b97b 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -32,7 +32,8 @@ function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = _checkconfig(depth, keys(LRESNET_CONFIGS)) layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain - loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d")) + loadpretrain!(layers, + string("resnext", depth, "_", cardinality, "x", base_width, "d")) end return ResNeXt(layers) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index de214bcbc..c272d724a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -136,7 +136,7 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer norm_layer = BatchNorm) invres = Chain(conv1..., conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad = pad, groups = hidden_planes)..., + bias = false, stride, pad, groups = hidden_planes)..., selayer, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 0e496097b..500c31811 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -73,12 +73,15 @@ function create_classifier(inplanes::Integer, nclasses::Integer, activation = id "`pool_layer` must be identity if `use_conv` is true" end classifier = [] - flatten_in_pool ? push!(classifier, pool_layer, MLUtils.flatten) : + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else push!(classifier, pool_layer) + end # Dropout is applied after the pooling layer isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) # Fully-connected layer use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) : - push!(classifier, Dense(inplanes => nclasses, activation)) + push!(classifier, Dense(inplanes => nclasses, activation)) return Chain(classifier...) end From d428da0d86d95cc65f384d442361a6faa665ef17 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 5 Aug 2022 11:39:44 +0530 Subject: [PATCH 02/34] Cleanup --- src/convnets/densenet.jl | 3 ++- src/convnets/efficientnet/efficientnet.jl | 2 +- src/convnets/inception/inceptionresnetv2.jl | 7 +++---- src/convnets/inception/inceptionv4.jl | 3 ++- src/convnets/mobilenet/mobilenetv2.jl | 12 ++++++------ src/convnets/mobilenet/mobilenetv3.jl | 17 ++++++++--------- src/convnets/resnets/core.jl | 12 ++++++------ src/convnets/resnets/resnext.jl | 8 +++++--- src/layers/conv.jl | 21 +++++++++++++-------- src/layers/mlp.jl | 7 +++++-- src/utilities.jl | 2 +- src/vit-based/vit.jl | 2 +- 12 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index eb29c4966..badb61a9e 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -132,7 +132,8 @@ end function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32, reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(DENSENET_CONFIGS)) - layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses) + layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, + nclasses) if pretrain loadpretrain!(layers, string("densenet", config)) end diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl index 795fbd592..518ff15f2 100644 --- a/src/convnets/efficientnet/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -38,7 +38,7 @@ function efficientnet(scalings::NTuple{2, Real}, outplanes = _round_channels(scalew(o), 8) repeats = scaled(n) push!(blocks, - invertedresidual((k, k), in_channels, outplanes, swish; expansion = e, + invertedresidual((k, k), inchannels, outplanes, swish; expansion = e, stride = s, reduction = 4)) for _ in 1:(repeats - 1) push!(blocks, diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index a8bfbaefa..54bc59952 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -75,7 +75,7 @@ Creates an InceptionResNetv2 model. - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, +function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., conv_norm((3, 3), 32, 32)..., @@ -96,8 +96,8 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, end """ - InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) + InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -106,7 +106,6 @@ Creates an InceptionResNetv2 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl index cd4971742..7f027da6e 100644 --- a/src/convnets/inception/inceptionv4.jl +++ b/src/convnets/inception/inceptionv4.jl @@ -121,7 +121,8 @@ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, end """ - Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) + Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index c73644073..531a7c9da 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -1,5 +1,5 @@ """ - mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; + mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1280, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -8,9 +8,6 @@ Create a MobileNetv2 model. # Arguments - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - `configs`: A "list of tuples" configuration for each layer that details: + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer @@ -18,11 +15,14 @@ Create a MobileNetv2 model. + `n`: The number of times a block is repeated + `s`: The stride of the convolutional kernel + `a`: The activation function used in the bottleneck layer + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes """ -function mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; +function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1280, inchannels::Integer = 3, nclasses::Integer = 1000) divisor = width_mult == 0.1 ? 4 : 8 @@ -86,7 +86,7 @@ end function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) + layers = mobilenetv2(MOBILENETV2_CONFIGS; width_mult, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv2")) end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 607069bdd..01acf9c54 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -1,5 +1,5 @@ """ - mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; + mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1024, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -8,10 +8,6 @@ Create a MobileNetv3 model. # Arguments - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `configs`: a "list of tuples" configuration for each layer that details: + `k::Integer` - The size of the convolutional kernel @@ -20,6 +16,9 @@ Create a MobileNetv3 model. + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + `s::Integer` - The stride of the convolutional kernel + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4.) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes @@ -45,13 +44,13 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = outplanes end # building last layers - output_channel = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : + headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : max_width append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(explanes, output_channel, hardswish), + Dense(explanes, headplanes, hardswish), Dropout(0.2), - Dense(output_channel, nclasses)) + Dense(headplanes, nclasses)) return Chain(Chain(layers...), classifier) end @@ -117,7 +116,7 @@ end function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, [:small, :large]) - max_width = (config == :large) ? 1280 : 1024 + max_width = config == :large ? 1280 : 1024 layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels, nclasses) if pretrain diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 699edcbe8..afc446d14 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -85,15 +85,15 @@ end # Downsample layer using convolutions. function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, revnorm = false) + norm_layer = BatchNorm, revnorm::Bool = false) return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, pad = SamePad(), stride, bias = false)...) end # Downsample layer using max pooling function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, revnorm = false) - pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) + norm_layer = BatchNorm, revnorm::Bool = false) + pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, bias = false)...) @@ -123,7 +123,7 @@ const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity), # Stride for each block in the ResNet model function resnet_stride(stage_idx::Integer, block_idx::Integer) - return (stage_idx == 1 || block_idx != 1) ? 1 : 2 + return stage_idx == 1 || block_idx != 1 ? 1 : 2 end # returns `DropBlock`s for each stage of the ResNet as in timm. @@ -221,7 +221,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_fn = stride != 1 || inplanes != planes * expansion ? downsample_tuple[1] : downsample_tuple[2] drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) @@ -256,7 +256,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_fn = stride != 1 || inplanes != planes * expansion ? downsample_tuple[1] : downsample_tuple[2] drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index bb589b97b..664202d15 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -8,11 +8,13 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. # Arguments - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. Supported configurations are: - - depth 50, cardinality of 32 and base width of 4. - - depth 101, cardinality of 32 and base width of 8. - - depth 101, cardinality of 64 and base width of 4. + + + depth 50, cardinality of 32 and base width of 4. + + depth 101, cardinality of 32 and base width of 8. + + depth 101, cardinality of 64 and base width of 4. - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. - `inchannels`: the number of input channels. diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c272d724a..cdbfc472c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -107,8 +107,13 @@ function depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Inte end """ - invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; - stride, reduction = nothing) + invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Integer} = nothing) + + invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, + activation = relu; stride::Integer, expansion::Real, + reduction::Union{Nothing, Integer} = nothing) Create a basic inverted residual block for MobileNet variants ([reference](https://arxiv.org/abs/1905.02244)). @@ -117,7 +122,8 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer + - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, + specify the keyword argument `expansion`, which calculates - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 @@ -129,7 +135,7 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer reduction::Union{Nothing, Integer} = nothing) @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 - conv1 = (inplanes == hidden_planes) ? (identity,) : + conv1 = inplanes == hidden_planes ? (identity,) : conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false) selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, @@ -139,13 +145,12 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer bias = false, stride, pad, groups = hidden_planes)..., selayer, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) - return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres + return stride == 1 && inplanes == outplanes ? SkipConnection(invres, +) : invres end function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) - hidden_planes = floor(Int, inplanes * expansion) - return invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation; - stride, reduction) + return invertedresidual(kernel_size, inplanes, floor(Int, inplanes * expansion), + outplanes, activation; stride, reduction) end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 500c31811..467df30a4 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -81,7 +81,10 @@ function create_classifier(inplanes::Integer, nclasses::Integer, activation = id # Dropout is applied after the pooling layer isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) # Fully-connected layer - use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) : - push!(classifier, Dense(inplanes => nclasses, activation)) + if use_conv + push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) + else + push!(classifier, Dense(inplanes => nclasses, activation)) + end return Chain(classifier...) end diff --git a/src/utilities.jl b/src/utilities.jl index f5737831c..4a611b5a2 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -6,7 +6,7 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) function _round_channels(channels, divisor, min_value = divisor) new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) # Make sure that round down does not go down by more than 10% - return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels + return new_channels < 0.9 * channels ? new_channels + divisor : new_channels end """ diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 099d00639..75bfb5b07 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -61,7 +61,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3, Dropout(emb_dropout_rate), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout_rate), - (pool == :class) ? x -> x[:, 1, :] : seconddimmean), + pool == :class ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end From 093f57317bb9076f9f3e2bb2a2c292de2a569b0b Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 08:38:59 +0530 Subject: [PATCH 03/34] Add docs for EfficientNetv2 Also misc. docs and formatting --- src/convnets/convmixer.jl | 4 +- src/convnets/convnext.jl | 2 +- src/convnets/efficientnet/efficientnetv2.jl | 50 ++++++++++++++++++++- src/convnets/inception/inceptionv3.jl | 3 +- src/convnets/inception/xception.jl | 2 +- src/convnets/mobilenet/mobilenetv1.jl | 4 +- src/convnets/mobilenet/mobilenetv2.jl | 4 +- src/convnets/mobilenet/mobilenetv3.jl | 8 ++-- src/convnets/resnets/core.jl | 18 ++++---- src/convnets/resnets/res2net.jl | 23 +++++----- src/convnets/resnets/resnet.jl | 2 +- src/convnets/resnets/resnext.jl | 3 +- src/convnets/resnets/seresnet.jl | 14 +++--- src/layers/conv.jl | 6 ++- src/layers/embeddings.jl | 2 +- 15 files changed, 99 insertions(+), 46 deletions(-) diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 309989d2d..c7dd058ff 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -13,7 +13,7 @@ Creates a ConvMixer model. - `kernel_size`: kernel size of the convolutional layers - `patch_size`: size of the patches - `activation`: activation function used after the convolutional layers - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), @@ -48,7 +48,7 @@ Creates a ConvMixer model. # Arguments - `config`: the size of the model, either `:base`, `:small` or `:large` - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ struct ConvMixer diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 040a409ab..15271cfed 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -84,7 +84,7 @@ Creates a ConvNeXt model. # Arguments - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`. - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of output classes See also [`Metalhead.convnext`](#). diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl index 58a7aea46..40265bd2c 100644 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -1,3 +1,28 @@ +""" + efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, + width_mult::Real = 1.0, inchannels::Integer = 3, + nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `t`: expansion factor of the block + + `c`: output channels of the block (will be scaled by width_mult) + + `n`: number of block repetitions + + `s`: kernel stride in the block except the first block of each stage + + `se`: whether to use a `squeeze_excite` layer in the block or not + + - `max_width`: maximum number of output channels before the fully connected + classification blocks + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, width_mult::Real = 1.0, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -8,13 +33,13 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, bias = false)) # building inverted residual blocks - for (t, c, n, s, r) in config + for (t, c, n, s, se) in config outplanes = _round_channels(c * width_mult, 8) for i in 1:n push!(layers, invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, stride = i == 1 ? s : 1, - reduction = r == 1 ? 4 : nothing)) + reduction = se == 1 ? 4 : nothing)) inplanes = outplanes end end @@ -25,6 +50,12 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) end +# config dict of inverted residual blocks for EfficientNetv2 +# t: expansion factor of the block +# c: output channels of the block (will be scaled by width_mult) +# n: number of block repetitions +# s: kernel stride in the block except the first block of each stage +# se: whether to use a `squeeze_excite` layer in the block or not const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE (1, 24, 2, 1, 0), (4, 48, 4, 2, 0), @@ -57,6 +88,21 @@ const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE (6, 512, 32, 2, 1), (6, 640, 8, 1, 1)]) +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `width_mult`: Controls the number of output feature maps in each block (with 1 + being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" struct EfficientNetv2 layers::Any end diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl index bc5ec3a2b..4f13d3695 100644 --- a/src/convnets/inception/inceptionv3.jl +++ b/src/convnets/inception/inceptionv3.jl @@ -133,7 +133,8 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - `nclasses`: the number of output classes """ -function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) +function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, + nclasses::Integer = 1000) backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., conv_norm((3, 3), 32, 32)..., conv_norm((3, 3), 32, 64; pad = 1)..., diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index 4964f3ca1..d4751352c 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -8,7 +8,7 @@ Create an Xception block. # Arguments - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `outchannels`: number of output channels. - `nrepeats`: number of repeats of depthwise separable convolution layers. - `stride`: stride by which to downsample the input. diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index b6d9fe8ee..b390a3f55 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -7,7 +7,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) + (with 1 being the default in the paper) - `configs`: A "list of tuples" configuration for each layer that details: @@ -63,7 +63,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 531a7c9da..fd5bc6691 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -17,7 +17,7 @@ Create a MobileNetv2 model. + `a`: The activation function used in the bottleneck layer - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) + (with 1 being the default in the paper) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes @@ -71,7 +71,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: Whether to load the pre-trained weights for ImageNet - `inchannels`: The number of input channels. diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 01acf9c54..82c5fb187 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -18,7 +18,7 @@ Create a MobileNetv3 model. + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4.) + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes @@ -45,7 +45,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end # building last layers headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : - max_width + max_width append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), @@ -100,10 +100,10 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `config`: :small or :large for the size of the model (see paper). - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: whether to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: the number of output classes See also [`Metalhead.mobilenetv3`](#). diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index afc446d14..1e6bb9fee 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -12,7 +12,7 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385 - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block - - `reduction_factor`: the factor by which the input feature maps are reduced before + - `reduction_factor`: the factor by which the input feature maps are reduced before the first convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. @@ -117,9 +117,9 @@ end # Shortcut configurations for the ResNet models const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity), - :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv), - :D => (downsample_pool, downsample_identity)) + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) # Stride for each block in the ResNet model function resnet_stride(stage_idx::Integer, block_idx::Integer) @@ -156,8 +156,8 @@ on how to use this function. convolution. This is an experimental variant from the `timm` library in Python that shows peformance improvements over the `:deep` stem in some cases. - - `inchannels`: The number of channels in the input. - - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + + - `inchannels`: number of input channels + - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. @@ -323,9 +323,9 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`" - @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`" - @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`" + @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0" + @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to 0.0" + @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 8e054da82..b5dae6663 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -8,6 +8,7 @@ Creates a bottleneck block as described in the Res2Net paper. ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block @@ -33,17 +34,15 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, for _ in 1:max(1, scale - 1)] reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) - tuplify = if is_first - x -> tuple(x...) - else - x -> tuple(x[1], tuple(x[2:end]...)) - end - layers = [conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm, bias = false)..., - chunk$(; size = width, dims = 3), tuplify, reslayer, - conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm, bias = false)..., - attn_fn(outplanes)] + tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) + layers = [ + conv_norm((1, 1), inplanes => width * scale, activation; + norm_layer, revnorm, bias = false)..., + chunk$(; size = width, dims = 3), tuplify, reslayer, + conv_norm((1, 1), width * scale => outplanes, activation; + norm_layer, revnorm, bias = false)..., + attn_fn(outplanes), + ] return Chain(filter(!=(identity), layers)...) end @@ -86,6 +85,7 @@ Creates a Res2Net model with the specified depth, scale, and base width. ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `scale`: the number of feature groups in the block. See the @@ -125,6 +125,7 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `scale`: the number of feature groups in the block. See the diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index cdccddd4b..f935c3b93 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -46,7 +46,7 @@ The number of channels in outer 1x1 convolutions is the same. - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the Wide ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: The number of input channels. - - `nclasses`: the number of output classes + - `nclasses`: The number of output classes Advanced users who want more configuration options will be better served by using [`resnet`](#). """ diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 664202d15..20fc912a2 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -32,7 +32,8 @@ end function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) - layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) + layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, + base_width) if pretrain loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d")) diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index da074e57d..ff39921b0 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -22,8 +22,6 @@ struct SEResNet end @functor SEResNet -(m::SEResNet)(x) = m.layers(x) - function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(RESNET_CONFIGS)) @@ -35,6 +33,8 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = return SEResNet(layers) end +(m::SEResNet)(x) = m.layers(x) + backbone(m::SEResNet) = m.layers[1] classifier(m::SEResNet) = m.layers[2] @@ -65,12 +65,12 @@ struct SEResNeXt end @functor SEResNeXt -(m::SEResNeXt)(x) = m.layers(x) - function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, - base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) + base_width::Integer = 4, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) - layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, + layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, + base_width, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) @@ -78,5 +78,7 @@ function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer return SEResNeXt(layers) end +(m::SEResNeXt)(x) = m.layers(x) + backbone(m::SEResNeXt) = m.layers[1] classifier(m::SEResNeXt) = m.layers[2] diff --git a/src/layers/conv.jl b/src/layers/conv.jl index cdbfc472c..7c7fe20af 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -122,8 +122,10 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, - specify the keyword argument `expansion`, which calculates + - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, + specify the keyword argument `expansion`, which calculates the number of feature + maps in the hidden layer from the number of input feature maps as: + `hidden_planes = inplanes * expansion` - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index cb9b8378c..abdab4b44 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -11,7 +11,7 @@ patches. # Arguments - `imsize`: the size of the input image - - `inchannels`: the number of channels in the input. + - `inchannels`: number of input channels - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding - `norm_layer`: the normalization layer - by default the identity function but otherwise takes a From add4d4190af2249a5a5f76667e9429cc409231a0 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 08:57:04 +0530 Subject: [PATCH 04/34] Add tests --- .github/workflows/CI.yml | 3 ++- test/convnets.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5e14d6c49..37cda3263 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,7 +28,8 @@ jobs: suite: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - - '["EfficientNet"]' + - '"EfficientNet"' + - '"EfficientNetv2"' - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' - '[r"Res2Net", r"Res2NeXt"]' diff --git a/test/convnets.jl b/test/convnets.jl index 6d7dab496..31d68d6d1 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -177,6 +177,20 @@ end end end +@testset "EfficientNetv2" begin + @testset for config in [:small, :medium, :large, :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end +end + @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(x_224)) == (1000, 1) From 7d5639624549dab809c4cd71def25fb9185a6617 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 16:12:06 +0530 Subject: [PATCH 05/34] Fix Inception bug, and other misc. cleanup --- src/convnets/densenet.jl | 4 +- src/convnets/efficientnet/efficientnet.jl | 10 +- src/convnets/efficientnet/efficientnetv2.jl | 83 +++++++------- src/convnets/inception/inceptionresnetv2.jl | 8 +- src/convnets/inception/inceptionv3.jl | 82 ++++++------- src/convnets/inception/inceptionv4.jl | 120 +++++++++----------- src/convnets/inception/xception.jl | 8 +- src/convnets/mobilenet/mobilenetv1.jl | 12 +- src/convnets/mobilenet/mobilenetv2.jl | 18 +-- src/convnets/mobilenet/mobilenetv3.jl | 10 +- src/layers/Layers.jl | 4 +- src/layers/conv.jl | 63 ++++++---- 12 files changed, 211 insertions(+), 211 deletions(-) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index badb61a9e..75e1ffde1 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -10,8 +10,8 @@ Create a Densenet bottleneck layer - `outplanes`: number of output feature maps on bottleneck branch (and scaling factor for inner feature maps; see ref) """ -function dense_bottleneck(inplanes::Integer, outplanes::Integer) - inner_channels = 4 * outplanes +function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) + inner_channels = expansion * outplanes return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl index 518ff15f2..677c57529 100644 --- a/src/convnets/efficientnet/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -17,10 +17,9 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). + `e`: expansion ratio + `i`: block input channels (will be scaled by global width scaling) + `o`: block output channels (will be scaled by global width scaling) + - `max_width`: The maximum number of feature maps in any layer of the network - `inchannels`: number of input channels - `nclasses`: number of output classes - - `max_width`: maximum number of output channels before the fully connected - classification blocks """ function efficientnet(scalings::NTuple{2, Real}, block_configs::AbstractVector{NTuple{6, Int}}; @@ -52,12 +51,7 @@ function efficientnet(scalings::NTuple{2, Real}, return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) end -# n: # of block repetitions -# k: kernel size k x k -# s: stride -# e: expantion ratio -# i: block input channels -# o: block output channels +# block configs for EfficientNet const EFFICIENTNET_BLOCK_CONFIGS = [ # (n, k, s, e, i, o) (1, 3, 1, 1, 32, 16), diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl index 40265bd2c..07f827095 100644 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -1,5 +1,5 @@ """ - efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, + efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, width_mult::Real = 1.0, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -14,16 +14,15 @@ Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + `c`: output channels of the block (will be scaled by width_mult) + `n`: number of block repetitions + `s`: kernel stride in the block except the first block of each stage - + `se`: whether to use a `squeeze_excite` layer in the block or not + + `r`: reduction factor of the squeeze-excite layer - - `max_width`: maximum number of output channels before the fully connected - classification blocks + - `max_width`: The maximum number of feature maps in any layer of the network - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper) - `inchannels`: number of input channels - `nclasses`: number of output classes """ -function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, +function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, width_mult::Real = 1.0, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer @@ -33,13 +32,12 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, bias = false)) # building inverted residual blocks - for (t, c, n, s, se) in config + for (t, c, n, s, reduction) in config outplanes = _round_channels(c * width_mult, 8) for i in 1:n push!(layers, invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, - stride = i == 1 ? s : 1, - reduction = se == 1 ? 4 : nothing)) + stride = i == 1 ? s : 1, reduction)) inplanes = outplanes end end @@ -50,43 +48,38 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) end -# config dict of inverted residual blocks for EfficientNetv2 -# t: expansion factor of the block -# c: output channels of the block (will be scaled by width_mult) -# n: number of block repetitions -# s: kernel stride in the block except the first block of each stage -# se: whether to use a `squeeze_excite` layer in the block or not -const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE - (1, 24, 2, 1, 0), - (4, 48, 4, 2, 0), - (4, 64, 4, 2, 0), - (4, 128, 6, 2, 1), - (6, 160, 9, 1, 1), - (6, 256, 15, 2, 1)], - :medium => [# t, c, n, s, SE - (1, 24, 3, 1, 0), - (4, 48, 5, 2, 0), - (4, 80, 5, 2, 0), - (4, 160, 7, 2, 1), - (6, 176, 14, 1, 1), - (6, 304, 18, 2, 1), - (6, 512, 5, 1, 1)], - :large => [# t, c, n, s, SE - (1, 32, 4, 1, 0), - (4, 64, 8, 2, 0), - (4, 96, 8, 2, 0), - (4, 192, 16, 2, 1), - (6, 256, 24, 1, 1), - (6, 512, 32, 2, 1), - (6, 640, 8, 1, 1)], - :xlarge => [# t, c, n, s, SE - (1, 32, 4, 1, 0), - (4, 64, 8, 2, 0), - (4, 96, 8, 2, 0), - (4, 192, 16, 2, 1), - (6, 256, 24, 1, 1), - (6, 512, 32, 2, 1), - (6, 640, 8, 1, 1)]) +# block configs for EfficientNetv2 +const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, r + (1, 24, 2, 1, nothing), + (4, 48, 4, 2, nothing), + (4, 64, 4, 2, nothing), + (4, 128, 6, 2, 4), + (6, 160, 9, 1, 4), + (6, 256, 15, 2, 4)], + :medium => [# t, c, n, s, r + (1, 24, 3, 1, nothing), + (4, 48, 5, 2, nothing), + (4, 80, 5, 2, nothing), + (4, 160, 7, 2, 4), + (6, 176, 14, 1, 4), + (6, 304, 18, 2, 4), + (6, 512, 5, 1, 4)], + :large => [# t, c, n, s, r + (1, 32, 4, 1, nothing), + (4, 64, 8, 2, nothing), + (4, 96, 8, 2, nothing), + (4, 192, 16, 2, 4), + (6, 256, 24, 1, 4), + (6, 512, 32, 2, 4), + (6, 640, 8, 1, 4)], + :xlarge => [# t, c, n, s, r + (1, 32, 4, 1, nothing), + (4, 64, 8, 2, nothing), + (4, 96, 8, 2, nothing), + (4, 192, 16, 2, 4), + (6, 256, 24, 1, 4), + (6, 512, 32, 2, 4), + (6, 640, 8, 1, 4)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index 54bc59952..a99973d09 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -34,8 +34,8 @@ end function block17(scale = 1.0f0) branch1 = Chain(conv_norm((1, 1), 1088, 192)...) branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((1, 7), 128, 160; pad = (0, 3))..., - conv_norm((7, 1), 160, 192; pad = (3, 0))...) + conv_norm((7, 1), 128, 160; pad = (3, 0))..., + conv_norm((1, 7), 160, 192; pad = (0, 3))...) branch3 = Chain(conv_norm((1, 1), 384, 1088)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation = relu)), +) @@ -56,8 +56,8 @@ end function block8(scale = 1.0f0; activation = identity) branch1 = Chain(conv_norm((1, 1), 2080, 192)...) branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((1, 3), 192, 224; pad = (0, 1))..., - conv_norm((3, 1), 224, 256; pad = (1, 0))...) + conv_norm((3, 1), 192, 224; pad = (1, 0))..., + conv_norm((1, 3), 224, 256; pad = (0, 1))...) branch3 = Chain(conv_norm((1, 1), 448, 2080)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation)), +) diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl index 4f13d3695..32fbbede5 100644 --- a/src/convnets/inception/inceptionv3.jl +++ b/src/convnets/inception/inceptionv3.jl @@ -10,14 +10,14 @@ Create an Inception-v3 style-A module - `pool_proj`: the number of output feature maps for the pooling projection """ function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 64)...) + branch5x5 = Chain(basic_conv_bn((1, 1), inplanes, 48)..., + basic_conv_bn((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(basic_conv_bn((1, 1), inplanes, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, pool_proj)...) + basic_conv_bn((1, 1), inplanes, pool_proj)...) return Parallel(cat_channels, branch1x1, branch5x5, branch3x3, branch_pool) end @@ -33,10 +33,10 @@ Create an Inception-v3 style-B module - `inplanes`: number of input feature maps """ function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; stride = 2)...) + branch3x3_1 = Chain(basic_conv_bn((3, 3), inplanes, 384; stride = 2)...) + branch3x3_2 = Chain(basic_conv_bn((1, 1), inplanes, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3_1, branch3x3_2, branch_pool) @@ -55,17 +55,17 @@ Create an Inception-v3 style-C module - `n`: the "grid size" (kernel size) for the convolution layers """ function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) - branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 192)...) + branch7x7_1 = Chain(basic_conv_bn((1, 1), inplanes, inner_planes)..., + basic_conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + basic_conv_bn((1, n), inner_planes, 192; pad = (0, 3))...) + branch7x7_2 = Chain(basic_conv_bn((1, 1), inplanes, inner_planes)..., + basic_conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., + basic_conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + basic_conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., + basic_conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) + basic_conv_bn((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, branch7x7_1, branch7x7_2, branch_pool) end @@ -81,12 +81,12 @@ Create an Inception-v3 style-D module - `inplanes`: number of input feature maps """ function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((1, 7), 192, 192; pad = (0, 3))..., - conv_norm((7, 1), 192, 192; pad = (3, 0))..., - conv_norm((3, 3), 192, 192; stride = 2)...) + branch3x3 = Chain(basic_conv_bn((1, 1), inplanes, 192)..., + basic_conv_bn((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(basic_conv_bn((1, 1), inplanes, 192)..., + basic_conv_bn((7, 1), 192, 192; pad = (3, 0))..., + basic_conv_bn((1, 7), 192, 192; pad = (0, 3))..., + basic_conv_bn((3, 3), 192, 192; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3, branch7x7x3, branch_pool) @@ -103,16 +103,16 @@ Create an Inception-v3 style-E module - `inplanes`: number of input feature maps """ function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., - conv_norm((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 320)...) + branch3x3_1 = Chain(basic_conv_bn((1, 1), inplanes, 384)...) + branch3x3_1a = Chain(basic_conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_1b = Chain(basic_conv_bn((1, 3), 384, 384; pad = (0, 1))...) + branch3x3_2 = Chain(basic_conv_bn((1, 1), inplanes, 448)..., + basic_conv_bn((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(basic_conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_2b = Chain(basic_conv_bn((1, 3), 384, 384; pad = (0, 1))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) + basic_conv_bn((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, Chain(branch3x3_1, @@ -135,12 +135,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). """ function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_norm((1, 1), 64, 80)..., - conv_norm((3, 3), 80, 192)..., + basic_conv_bn((1, 1), 64, 80)..., + basic_conv_bn((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), inceptionv3_a(192, 32), inceptionv3_a(256, 64), diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl index 7f027da6e..b43f6bc1d 100644 --- a/src/convnets/inception/inceptionv4.jl +++ b/src/convnets/inception/inceptionv4.jl @@ -1,83 +1,86 @@ function mixed_3a() return Parallel(cat_channels, MaxPool((3, 3); stride = 2), - Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) + Chain(basic_conv_bn((3, 3), 64, 96; stride = 2)...)) end function mixed_4a() return Parallel(cat_channels, - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((3, 3), 64, 96)...), - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((1, 7), 64, 64; pad = (0, 3))..., - conv_norm((7, 1), 64, 64; pad = (3, 0))..., - conv_norm((3, 3), 64, 96)...)) + Chain(basic_conv_bn((1, 1), 160, 64)..., + basic_conv_bn((3, 3), 64, 96)...), + Chain(basic_conv_bn((1, 1), 160, 64)..., + basic_conv_bn((7, 1), 64, 64; pad = (3, 0))..., + basic_conv_bn((1, 7), 64, 64; pad = (0, 3))..., + basic_conv_bn((3, 3), 64, 96)...)) end function mixed_5a() return Parallel(cat_channels, - Chain(conv_norm((3, 3), 192, 192; stride = 2)...), + Chain(basic_conv_bn((3, 3), 192, 192; stride = 2)...), MaxPool((3, 3); stride = 2)) end function inceptionv4_a() - branch1 = Chain(conv_norm((1, 1), 384, 96)...) - branch2 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) + branch1 = Chain(basic_conv_bn((1, 1), 384, 96)...) + branch2 = Chain(basic_conv_bn((1, 1), 384, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)...) + branch3 = Chain(basic_conv_bn((1, 1), 384, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 384, 96)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function reduction_a() - branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 384, 192)..., - conv_norm((3, 3), 192, 224; pad = 1)..., - conv_norm((3, 3), 224, 256; stride = 2)...) + branch1 = Chain(basic_conv_bn((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 384, 192)..., + basic_conv_bn((3, 3), 192, 224; pad = 1)..., + basic_conv_bn((3, 3), 224, 256; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function inceptionv4_b() - branch1 = Chain(conv_norm((1, 1), 1024, 384)...) - branch2 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((1, 7), 192, 224; pad = (0, 3))..., - conv_norm((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((7, 1), 192, 192; pad = (0, 3))..., - conv_norm((1, 7), 192, 224; pad = (3, 0))..., - conv_norm((7, 1), 224, 224; pad = (0, 3))..., - conv_norm((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) + branch1 = Chain(basic_conv_bn((1, 1), 1024, 384)...) + branch2 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((7, 1), 192, 224; pad = (0, 3))..., + basic_conv_bn((1, 7), 224, 256; pad = (3, 0))...) + branch3 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((1, 7), 192, 192; pad = (3, 0))..., + basic_conv_bn((7, 1), 192, 224; pad = (0, 3))..., + basic_conv_bn((1, 7), 224, 224; pad = (3, 0))..., + basic_conv_bn((7, 1), 224, 256; pad = (0, 3))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 1024, 128)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function reduction_b() - branch1 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1024, 256)..., - conv_norm((1, 7), 256, 256; pad = (0, 3))..., - conv_norm((7, 1), 256, 320; pad = (3, 0))..., - conv_norm((3, 3), 320, 320; stride = 2)...) + branch1 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 1024, 256)..., + basic_conv_bn((7, 1), 256, 256; pad = (3, 0))..., + basic_conv_bn((1, 7), 256, 320; pad = (0, 3))..., + basic_conv_bn((3, 3), 320, 320; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function inceptionv4_c() - branch1 = Chain(conv_norm((1, 1), 1536, 256)...) - branch2 = Chain(conv_norm((1, 1), 1536, 384)..., + branch1 = Chain(basic_conv_bn((1, 1), 1536, 256)...) + branch2 = Chain(basic_conv_bn((1, 1), 1536, 384)..., Parallel(cat_channels, - Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_norm((1, 1), 1536, 384)..., - conv_norm((3, 1), 384, 448; pad = (1, 0))..., - conv_norm((1, 3), 448, 512; pad = (0, 1))..., + Chain(basic_conv_bn((3, 1), 384, 256; pad = (1, 0))...), + Chain(basic_conv_bn((1, 3), 384, 256; pad = (0, 1))...))) + branch3 = Chain(basic_conv_bn((1, 1), 1536, 384)..., + basic_conv_bn((1, 3), 384, 448; pad = (0, 1))..., + basic_conv_bn((3, 1), 448, 512; pad = (1, 0))..., Parallel(cat_channels, - Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) + Chain(basic_conv_bn((3, 1), 512, 256; pad = (1, 0))...), + Chain(basic_conv_bn((1, 3), 512, 256; pad = (0, 1))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 1536, 256)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end @@ -95,28 +98,15 @@ Create an Inceptionv4 model. """ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - mixed_3a(), - mixed_4a(), - mixed_5a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., + mixed_3a(), mixed_4a(), mixed_5a(), + [inceptionv4_a() for _ in 1:4]..., reduction_a(), # mixed_6a - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), + [inceptionv4_b() for _ in 1:7]..., reduction_b(), # mixed_7a - inceptionv4_c(), - inceptionv4_c(), - inceptionv4_c()) + [inceptionv4_c() for _ in 1:3]...) return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index d4751352c..14b5444d6 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -35,8 +35,8 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, - use_norm = (false, false))) + dwsep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, + use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -64,8 +64,8 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integ xception_block(256, 728, 2; stride = 2), [xception_block(728, 728, 3) for _ in 1:8]..., xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...) + dwsep_conv_bn((3, 3), 1024, 1536; pad = 1)..., + dwsep_conv_bn((3, 3), 1536, 2048; pad = 1)...) return Chain(backbone, create_classifier(2048, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index b390a3f55..db9dedbdb 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -22,16 +22,16 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] - for (dw, outch, stride, nrepeats) in config - outch = floor(Int, outch * width_mult) + for (dw, outchannels, stride, nrepeats) in config + outchannels = floor(Int, outchannels * width_mult) for _ in 1:nrepeats layer = dw ? - depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, + dwsep_conv_bn((3, 3), inchannels, outchannels, activation; + stride, pad = 1, bias = false) : + conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1, bias = false) append!(layers, layer) - inchannels = outch + inchannels = outchannels end end return Chain(Chain(layers...), create_classifier(inchannels, nclasses)) diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index fd5bc6691..39db8933b 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -1,7 +1,7 @@ """ mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv2 model. ([reference](https://arxiv.org/abs/1801.04381)). @@ -18,14 +18,15 @@ Create a MobileNetv2 model. - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper) - - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network + - `divisor`: The divisor used to round the number of feature maps in each block + - `dropout_rate`: rate of dropout in the classifier head + - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) - divisor = width_mult == 0.1 ? 4 : 8 + max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(32 * width_mult, divisor) layers = [] @@ -42,10 +43,9 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end end # building last layers - outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, divisor) : - max_width + outplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 82c5fb187..c3bca7a95 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -34,13 +34,13 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, bias = false)) explanes = 0 # building inverted residual blocks - for (k, t, c, r, a, s) in configs + for (k, t, c, reduction, activation, stride) in configs # inverted residual layers outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, - invertedresidual((k, k), inplanes, explanes, outplanes, a; - stride = s, reduction = r)) + invertedresidual((k, k), inplanes, explanes, outplanes, activation; + stride, reduction)) inplanes = outplanes end # building last layers @@ -56,7 +56,7 @@ end # Layer configurations for small and large models for MobileNetv3 const MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, SE, a, s + # k, t, c, r, a, s (3, 1, 16, 4, relu, 2), (3, 4.5, 24, nothing, relu, 2), (3, 3.67, 24, nothing, relu, 1), @@ -70,7 +70,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => [ (5, 6, 96, 4, hardswish, 1), ], :large => [ - # k, t, c, SE, a, s + # k, t, c, r, a, s (3, 1, 16, nothing, relu, 1), (3, 4, 24, nothing, relu, 2), (3, 3, 24, nothing, relu, 1), diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 04be476ff..ec48ab772 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -11,13 +11,15 @@ using MLUtils using PartialFunctions using Random +import Flux.testmode! + include("../utilities.jl") include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, depthwise_sep_conv_norm, invertedresidual +export conv_norm, basic_conv_bn, dwsep_conv_bn, invertedresidual include("drop.jl") export DropBlock, DropPath diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7c7fe20af..fcccae691 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -28,22 +28,34 @@ Create a convolution + batch normalization pair with activation. - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, + momentum::Union{Nothing, Float32} = nothing, preact::Bool = false, use_norm::Bool = true, kwargs...) + # handle momentum for BatchNorm + if norm_layer == BatchNorm + momentum = isnothing(momentum) ? 0.1f0 : momentum + norm_layer = (args...; kargs...) -> BatchNorm(args...; momentum, kargs...) + elseif norm_layer != BatchNorm && !isnothing(momentum) + error("momentum is only supported for BatchNorm") + end + # no normalization layer if !use_norm - if (preact || revnorm) + if preact || revnorm throw(ArgumentError("`preact` only supported with `use_norm = true`")) else + # early return if no norm layer is required return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] end end + # channels for norm layer and activation functions for both conv and norm if revnorm activations = (conv = activation, bn = identity) - bnplanes = inplanes + normplanes = inplanes else activations = (conv = identity, bn = activation) - bnplanes = outplanes + normplanes = outplanes end + # handle pre-activation if preact if revnorm throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time")) @@ -51,8 +63,9 @@ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activatio activations = (conv = activation, bn = identity) end end + # layers layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), - norm_layer(bnplanes, activations.bn)] + norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end @@ -62,8 +75,14 @@ function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = ide return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end +# conv + bn layer combination as used by the inception model family +function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) + return conv_norm(kernel_size, inplanes, outplanes, activation; eps = 1.0f-3, + bias = false, kwargs...) +end + """ - depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, + dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), @@ -95,15 +114,15 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), kwargs...) +function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, momentum::Float32 = 0.1f0, + revnorm::Bool = false, stride::Integer = 1, + use_norm::NTuple{2, Bool} = (true, true), kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; - norm_layer, revnorm, use_norm = use_norm[1], stride, + revnorm, use_norm = use_norm[1], stride, groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, - use_norm = use_norm[2])) + conv_norm((1, 1), inplanes, outplanes, activation; + revnorm, use_norm = use_norm[2])) end """ @@ -134,7 +153,8 @@ Create a basic inverted residual block for MobileNet variants """ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing) + reduction::Union{Nothing, Integer} = nothing, + momentum::Float32 = 0.1f0) @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 conv1 = inplanes == hidden_planes ? (identity,) : @@ -142,17 +162,18 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, norm_layer = BatchNorm) - invres = Chain(conv1..., - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad, groups = hidden_planes)..., - selayer, - conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) - return stride == 1 && inplanes == outplanes ? SkipConnection(invres, +) : invres + invres = [conv1..., + conv_norm(kernel_size, hidden_planes, hidden_planes, activation; + bias = false, stride, pad, groups = hidden_planes, momentum)..., + selayer, + conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...] + layers = Chain(filter(!=(identity), invres)...) + return stride == 1 && inplanes == outplanes ? SkipConnection(layers, +) : layers end function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) - return invertedresidual(kernel_size, inplanes, floor(Int, inplanes * expansion), + return invertedresidual(kernel_size, inplanes, round(Int, inplanes * expansion), outplanes, activation; stride, reduction) end From d3e4add13ceb4b53596c9e74de1bed54999bccf4 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 23:54:04 +0530 Subject: [PATCH 06/34] Refactor: `mbconv` instead of `invertedresidual` Also fix bug in EfficientNet models --- src/convnets/efficientnet/efficientnet.jl | 18 +++-- src/convnets/efficientnet/efficientnetv2.jl | 79 +++++++++--------- src/convnets/inception/inceptionresnetv2.jl | 82 +++++++++---------- src/convnets/mobilenet/mobilenetv2.jl | 4 +- src/convnets/mobilenet/mobilenetv3.jl | 10 +-- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 90 ++++++++++++--------- src/layers/selayers.jl | 36 ++++++--- test/convnets.jl | 2 +- 9 files changed, 178 insertions(+), 145 deletions(-) diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl index 677c57529..4f06d81b7 100644 --- a/src/convnets/efficientnet/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -17,35 +17,37 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). + `e`: expansion ratio + `i`: block input channels (will be scaled by global width scaling) + `o`: block output channels (will be scaled by global width scaling) - - `max_width`: The maximum number of feature maps in any layer of the network - `inchannels`: number of input channels - `nclasses`: number of output classes """ function efficientnet(scalings::NTuple{2, Real}, block_configs::AbstractVector{NTuple{6, Int}}; - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) + inchannels::Integer = 3, nclasses::Integer = 1000) + # building first layer wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) outplanes = _round_channels(scalew(32), 8) stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, pad = SamePad()) + # building inverted residual blocks blocks = [] for (n, k, s, e, i, o) in block_configs inchannels = _round_channels(scalew(i), 8) + explanes = _round_channels(inchannels * e, 8) outplanes = _round_channels(scalew(o), 8) repeats = scaled(n) push!(blocks, - invertedresidual((k, k), inchannels, outplanes, swish; expansion = e, - stride = s, reduction = 4)) + mbconv((k, k), inchannels, explanes, outplanes, swish; + stride = s, reduction = 4)) for _ in 1:(repeats - 1) push!(blocks, - invertedresidual((k, k), outplanes, outplanes, swish; expansion = e, - stride = 1, reduction = 4)) + mbconv((k, k), outplanes, explanes, outplanes, swish; + stride = 1, reduction = 4)) end end - headplanes = _round_channels(max_width, 8) + # building last layers + headplanes = outplanes * 4 append!(blocks, conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl index 07f827095..afcba5dab 100644 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -32,12 +32,19 @@ function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 17 conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, bias = false)) # building inverted residual blocks - for (t, c, n, s, reduction) in config - outplanes = _round_channels(c * width_mult, 8) + for (t, inplanes, outplanes, n, s, reduction) in config + explanes = _round_channels(inplanes * t, 8) for i in 1:n - push!(layers, - invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, - stride = i == 1 ? s : 1, reduction)) + stride = i == 1 ? s : 1 + if isnothing(reduction) + push!(layers, + fused_mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) + else + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + push!(layers, + mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) + end inplanes = outplanes end end @@ -49,37 +56,37 @@ function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 17 end # block configs for EfficientNetv2 -const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, r - (1, 24, 2, 1, nothing), - (4, 48, 4, 2, nothing), - (4, 64, 4, 2, nothing), - (4, 128, 6, 2, 4), - (6, 160, 9, 1, 4), - (6, 256, 15, 2, 4)], - :medium => [# t, c, n, s, r - (1, 24, 3, 1, nothing), - (4, 48, 5, 2, nothing), - (4, 80, 5, 2, nothing), - (4, 160, 7, 2, 4), - (6, 176, 14, 1, 4), - (6, 304, 18, 2, 4), - (6, 512, 5, 1, 4)], - :large => [# t, c, n, s, r - (1, 32, 4, 1, nothing), - (4, 64, 8, 2, nothing), - (4, 96, 8, 2, nothing), - (4, 192, 16, 2, 4), - (6, 256, 24, 1, 4), - (6, 512, 32, 2, 4), - (6, 640, 8, 1, 4)], - :xlarge => [# t, c, n, s, r - (1, 32, 4, 1, nothing), - (4, 64, 8, 2, nothing), - (4, 96, 8, 2, nothing), - (4, 192, 16, 2, 4), - (6, 256, 24, 1, 4), - (6, 512, 32, 2, 4), - (6, 640, 8, 1, 4)]) +const EFFNETV2_CONFIGS = Dict(:small => [ + (1, 24, 24, 2, 1, nothing), + (4, 24, 48, 4, 2, nothing), + (4, 48, 64, 4, 2, nothing), + (4, 64, 128, 6, 2, 4), + (6, 128, 160, 9, 1, 4), + (6, 160, 256, 15, 2, 4)], + :medium => [ + (1, 24, 24, 3, 1, nothing), + (4, 24, 48, 5, 2, nothing), + (4, 48, 80, 5, 2, nothing), + (4, 80, 160, 7, 2, 4), + (6, 160, 176, 14, 1, 4), + (6, 176, 304, 18, 2, 4), + (6, 304, 512, 5, 1, 4)], + :large => [ + (1, 32, 32, 4, 1, nothing), + (4, 32, 64, 7, 2, nothing), + (4, 64, 96, 7, 2, nothing), + (4, 96, 192, 10, 2, 4), + (6, 192, 224, 19, 1, 4), + (6, 224, 384, 25, 2, 4), + (6, 384, 640, 7, 1, 4)], + :xlarge => [ + (1, 32, 32, 4, 1, nothing), + (4, 32, 64, 8, 2, nothing), + (4, 64, 96, 8, 2, nothing), + (4, 96, 192, 16, 2, 4), + (6, 192, 256, 24, 1, 4), + (6, 256, 512, 32, 2, 4), + (6, 512, 640, 8, 1, 4)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index a99973d09..7f462c0cf 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -1,64 +1,64 @@ function mixed_5b() - branch1 = Chain(conv_norm((1, 1), 192, 96)...) - branch2 = Chain(conv_norm((1, 1), 192, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_norm((1, 1), 192, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) + branch1 = Chain(basic_conv_bn((1, 1), 192, 96)...) + branch2 = Chain(basic_conv_bn((1, 1), 192, 48)..., + basic_conv_bn((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(basic_conv_bn((1, 1), 192, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), 192, 64)...) + basic_conv_bn((1, 1), 192, 64)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block35(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 320, 32)...) - branch2 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 48; pad = 1)..., - conv_norm((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_norm((1, 1), 128, 320)...) + branch1 = Chain(basic_conv_bn((1, 1), 320, 32)...) + branch2 = Chain(basic_conv_bn((1, 1), 320, 32)..., + basic_conv_bn((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(basic_conv_bn((1, 1), 320, 32)..., + basic_conv_bn((3, 3), 32, 48; pad = 1)..., + basic_conv_bn((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(basic_conv_bn((1, 1), 128, 320)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), branch4, inputscale(scale; activation = relu)), +) end function mixed_6a() - branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 320, 256)..., - conv_norm((3, 3), 256, 256; pad = 1)..., - conv_norm((3, 3), 256, 384; stride = 2)...) + branch1 = Chain(basic_conv_bn((3, 3), 320, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 320, 256)..., + basic_conv_bn((3, 3), 256, 256; pad = 1)..., + basic_conv_bn((3, 3), 256, 384; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function block17(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 1088, 192)...) - branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((7, 1), 128, 160; pad = (3, 0))..., - conv_norm((1, 7), 160, 192; pad = (0, 3))...) - branch3 = Chain(conv_norm((1, 1), 384, 1088)...) + branch1 = Chain(basic_conv_bn((1, 1), 1088, 192)...) + branch2 = Chain(basic_conv_bn((1, 1), 1088, 128)..., + basic_conv_bn((7, 1), 128, 160; pad = (3, 0))..., + basic_conv_bn((1, 7), 160, 192; pad = (0, 3))...) + branch3 = Chain(basic_conv_bn((1, 1), 384, 1088)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation = relu)), +) end function mixed_7a() - branch1 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; pad = 1)..., - conv_norm((3, 3), 288, 320; stride = 2)...) + branch1 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 288; pad = 1)..., + basic_conv_bn((3, 3), 288, 320; stride = 2)...) branch4 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_norm((1, 1), 2080, 192)...) - branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((3, 1), 192, 224; pad = (1, 0))..., - conv_norm((1, 3), 224, 256; pad = (0, 1))...) - branch3 = Chain(conv_norm((1, 1), 448, 2080)...) + branch1 = Chain(basic_conv_bn((1, 1), 2080, 192)...) + branch2 = Chain(basic_conv_bn((1, 1), 2080, 192)..., + basic_conv_bn((3, 1), 192, 224; pad = (1, 0))..., + basic_conv_bn((1, 3), 224, 256; pad = (0, 1))...) + branch3 = Chain(basic_conv_bn((1, 1), 448, 2080)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation)), +) end @@ -77,12 +77,12 @@ Creates an InceptionResNetv2 model. """ function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_norm((3, 3), 64, 80)..., - conv_norm((3, 3), 80, 192)..., + basic_conv_bn((3, 3), 64, 80)..., + basic_conv_bn((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), mixed_5b(), [block35(0.17f0) for _ in 1:10]..., @@ -91,7 +91,7 @@ function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, mixed_7a(), [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), - conv_norm((1, 1), 2080, 1536)...) + basic_conv_bn((1, 1), 2080, 1536)...) return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 39db8933b..7bc87bcb9 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -37,8 +37,8 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, outplanes = _round_channels(c * width_mult, divisor) for i in 1:n push!(layers, - invertedresidual((3, 3), inplanes, outplanes, a; expansion = t, - stride = i == 1 ? s : 1)) + mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a; + stride = i == 1 ? s : 1)) inplanes = outplanes end end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index c3bca7a95..68fe2f03b 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -24,8 +24,8 @@ Create a MobileNetv3 model. - `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] @@ -39,8 +39,8 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, - invertedresidual((k, k), inplanes, explanes, outplanes, activation; - stride, reduction)) + mbconv((k, k), inplanes, explanes, outplanes, activation; + stride, reduction)) inplanes = outplanes end # building last layers @@ -49,7 +49,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), - Dropout(0.2), + Dropout(dropout_rate), Dense(headplanes, nclasses)) return Chain(Chain(layers...), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index ec48ab772..72ace2c2c 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -19,7 +19,7 @@ include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, basic_conv_bn, dwsep_conv_bn, invertedresidual +export conv_norm, basic_conv_bn, dwsep_conv_bn, mbconv, fused_mbconv include("drop.jl") export DropBlock, DropPath diff --git a/src/layers/conv.jl b/src/layers/conv.jl index fcccae691..087082c8b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -29,15 +29,7 @@ Create a convolution + batch normalization pair with activation. """ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, - momentum::Union{Nothing, Float32} = nothing, preact::Bool = false, - use_norm::Bool = true, kwargs...) - # handle momentum for BatchNorm - if norm_layer == BatchNorm - momentum = isnothing(momentum) ? 0.1f0 : momentum - norm_layer = (args...; kargs...) -> BatchNorm(args...; momentum, kargs...) - elseif norm_layer != BatchNorm && !isnothing(momentum) - error("momentum is only supported for BatchNorm") - end + preact::Bool = false, use_norm::Bool = true, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -75,10 +67,11 @@ function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = ide return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end -# conv + bn layer combination as used by the inception model family +# conv + bn layer combination as used by the inception model family matching +# the default values used in TensorFlow function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) - return conv_norm(kernel_size, inplanes, outplanes, activation; eps = 1.0f-3, - bias = false, kwargs...) + return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, + eps = 1.0f-3, bias = false, kwargs...) end """ @@ -115,22 +108,22 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, momentum::Float32 = 0.1f0, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), kwargs...) - return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, revnorm, use_norm = use_norm[1], stride, groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; + conv_norm((1, 1), inplanes, outplanes, activation; eps, revnorm, use_norm = use_norm[2])) end """ - invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, + mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, outplanes::Integer, activation = relu; stride::Integer, reduction::Union{Nothing, Integer} = nothing) - invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, + mbconv(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) @@ -151,29 +144,48 @@ Create a basic inverted residual block for MobileNet variants - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)). """ -function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing, - momentum::Float32 = 0.1f0) +function mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" - pad = @. (kernel_size - 1) ÷ 2 - conv1 = inplanes == hidden_planes ? (identity,) : - conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false) - selayer = isnothing(reduction) ? identity : - squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, - norm_layer = BatchNorm) - invres = [conv1..., - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad, groups = hidden_planes, momentum)..., - selayer, - conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...] - layers = Chain(filter(!=(identity), invres)...) - return stride == 1 && inplanes == outplanes ? SkipConnection(layers, +) : layers + layers = [] + # expand + if inplanes != hidden_planes + append!(layers, + conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false, + norm_layer)) + end + # squeeze-excite layer + if !isnothing(reduction) + append!(layers, + squeeze_excite(hidden_planes, inplanes ÷ reduction; activation, + gate_activation = hardσ)) + end + # depthwise + append!(layers, + conv_norm(kernel_size, hidden_planes, hidden_planes, activation; bias = false, + norm_layer, stride, pad = SamePad(), groups = hidden_planes)) + # project + append!(layers, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)) + return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : + Chain(layers...) end -function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion::Real, - reduction::Union{Nothing, Integer} = nothing) - return invertedresidual(kernel_size, inplanes, round(Int, inplanes * expansion), - outplanes, activation; stride, reduction) +function fused_mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, + activation = relu; stride::Integer, norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2" + layers = [] + if explanes != inplanes + # fused expand + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride)) + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) + else + append!(layers, + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) + end + return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : + Chain(layers...) end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 0756225ba..db034d1af 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -1,34 +1,46 @@ """ + squeeze_excite(inplanes::Integer, squeeze_planes::Integer; + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) + squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, - activation = relu, gate_activation = sigmoid, norm_layer = identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + activation = relu, gate_activation = sigmoid, norm_layer = identity) -Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. +Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets. # Arguments - `inplanes`: The number of input feature maps - - `reduction`: The reduction factor for the number of hidden feature maps - - `rd_divisor`: The divisor for the number of hidden feature maps. + - `squeeze_planes`: The number of feature maps in the intermediate layers. Alternatively, + specify the keyword arguments `reduction` and `rd_divisior`, which determine the number + of feature maps in the intermediate layers from the number of input feature maps as: + `squeeze_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)`. + (See [`_round_channels`](#) for details. The default values are `reduction = 16` and + `rd_divisor = 8`.) - `activation`: The activation function for the first convolution layer - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer """ -function squeeze_excite(inplanes::Integer; reduction::Integer = 16, - rd_divisor::Integer = 8, activation = relu, - gate_activation = sigmoid, norm_layer = planes -> identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) +function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) layers = [AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes => rd_planes), - norm_layer(rd_planes), + Conv((1, 1), inplanes => squeeze_planes), + norm_layer(squeeze_planes), activation, - Conv((1, 1), rd_planes => inplanes), + Conv((1, 1), squeeze_planes => inplanes), norm_layer(inplanes), gate_activation] return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end +function squeeze_excite(inplanes::Integer; reduction::Integer = 16, rd_divisor::Integer = 8, + kwargs...) + return squeeze_excite(inplanes, _round_channels(inplanes ÷ reduction, rd_divisor, 0); + kwargs...) +end + """ effective_squeeze_excite(inplanes, gate_activation = sigmoid) diff --git a/test/convnets.jl b/test/convnets.jl index 31d68d6d1..e087ceb0e 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -161,7 +161,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] x = rand(Float32, r, r, 3, 1) From 6504f256ca77ae21fe1230f125f2961fd8825ed5 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 12 Aug 2022 13:51:02 +0530 Subject: [PATCH 07/34] Refactor EfficientNets --- src/Metalhead.jl | 21 +-- src/convnets/convmixer.jl | 4 +- src/convnets/densenet.jl | 8 +- src/convnets/efficientnet/efficientnet.jl | 113 ---------------- src/convnets/efficientnet/efficientnetv2.jl | 124 ------------------ src/convnets/efficientnets/core.jl | 78 +++++++++++ src/convnets/efficientnets/efficientnet.jl | 59 +++++++++ src/convnets/efficientnets/efficientnetv2.jl | 67 ++++++++++ .../{inception => inceptions}/googlenet.jl | 0 .../inceptionresnetv2.jl | 0 .../{inception => inceptions}/inceptionv3.jl | 0 .../{inception => inceptions}/inceptionv4.jl | 0 .../{inception => inceptions}/xception.jl | 10 +- .../{mobilenet => mobilenets}/mobilenetv1.jl | 5 +- .../{mobilenet => mobilenets}/mobilenetv2.jl | 4 +- .../{mobilenet => mobilenets}/mobilenetv3.jl | 5 +- src/convnets/resnets/core.jl | 24 ++-- src/convnets/resnets/res2net.jl | 6 +- src/convnets/vgg.jl | 2 +- src/layers/conv.jl | 104 ++++++++------- src/layers/drop.jl | 1 + test/convnets.jl | 2 +- 22 files changed, 300 insertions(+), 337 deletions(-) delete mode 100644 src/convnets/efficientnet/efficientnet.jl delete mode 100644 src/convnets/efficientnet/efficientnetv2.jl create mode 100644 src/convnets/efficientnets/core.jl create mode 100644 src/convnets/efficientnets/efficientnet.jl create mode 100644 src/convnets/efficientnets/efficientnetv2.jl rename src/convnets/{inception => inceptions}/googlenet.jl (100%) rename src/convnets/{inception => inceptions}/inceptionresnetv2.jl (100%) rename src/convnets/{inception => inceptions}/inceptionv3.jl (100%) rename src/convnets/{inception => inceptions}/inceptionv4.jl (100%) rename src/convnets/{inception => inceptions}/xception.jl (92%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv1.jl (95%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv2.jl (97%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv3.jl (98%) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index cad13afc9..6b0179f45 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -28,18 +28,19 @@ include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") include("convnets/resnets/res2net.jl") ## Inceptions -include("convnets/inception/googlenet.jl") -include("convnets/inception/inceptionv3.jl") -include("convnets/inception/inceptionv4.jl") -include("convnets/inception/inceptionresnetv2.jl") -include("convnets/inception/xception.jl") +include("convnets/inceptions/googlenet.jl") +include("convnets/inceptions/inceptionv3.jl") +include("convnets/inceptions/inceptionv4.jl") +include("convnets/inceptions/inceptionresnetv2.jl") +include("convnets/inceptions/xception.jl") ## EfficientNets -include("convnets/efficientnet/efficientnet.jl") -include("convnets/efficientnet/efficientnetv2.jl") +include("convnets/efficientnets/core.jl") +include("convnets/efficientnets/efficientnet.jl") +include("convnets/efficientnets/efficientnetv2.jl") ## MobileNets -include("convnets/mobilenet/mobilenetv1.jl") -include("convnets/mobilenet/mobilenetv2.jl") -include("convnets/mobilenet/mobilenetv3.jl") +include("convnets/mobilenets/mobilenetv1.jl") +include("convnets/mobilenets/mobilenetv2.jl") +include("convnets/mobilenets/mobilenetv3.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index c7dd058ff..1ca8487a9 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -1,5 +1,5 @@ """ - convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), + convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -16,7 +16,7 @@ Creates a ConvMixer model. - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ -function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), +function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 75e1ffde1..ca81b78ea 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -12,10 +12,10 @@ Create a Densenet bottleneck layer """ function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) inner_channels = expansion * outplanes - return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, + return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, - bias = false, revnorm = true)...), + revnorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes::Integer, outplanes::Integer) - return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; revnorm = true)..., MeanPool((2, 2))) end @@ -72,7 +72,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: nclasses::Integer = 1000) layers = [] append!(layers, - conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false)) + conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl deleted file mode 100644 index 4f06d81b7..000000000 --- a/src/convnets/efficientnet/efficientnet.jl +++ /dev/null @@ -1,113 +0,0 @@ -""" - efficientnet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - -# Arguments - - - `scalings`: global width and depth scaling (given as a tuple) - - - `block_configs`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `n`: number of block repetitions (will be scaled by global depth scaling) - + `k`: kernel size - + `s`: kernel stride - + `e`: expansion ratio - + `i`: block input channels (will be scaled by global width scaling) - + `o`: block output channels (will be scaled by global width scaling) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -function efficientnet(scalings::NTuple{2, Real}, - block_configs::AbstractVector{NTuple{6, Int}}; - inchannels::Integer = 3, nclasses::Integer = 1000) - # building first layer - wscale, dscale = scalings - scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) - scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - outplanes = _round_channels(scalew(32), 8) - stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, - pad = SamePad()) - # building inverted residual blocks - blocks = [] - for (n, k, s, e, i, o) in block_configs - inchannels = _round_channels(scalew(i), 8) - explanes = _round_channels(inchannels * e, 8) - outplanes = _round_channels(scalew(o), 8) - repeats = scaled(n) - push!(blocks, - mbconv((k, k), inchannels, explanes, outplanes, swish; - stride = s, reduction = 4)) - for _ in 1:(repeats - 1) - push!(blocks, - mbconv((k, k), outplanes, explanes, outplanes, swish; - stride = 1, reduction = 4)) - end - end - # building last layers - headplanes = outplanes * 4 - append!(blocks, - conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) - return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) -end - -# block configs for EfficientNet -const EFFICIENTNET_BLOCK_CONFIGS = [ - # (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), - (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320), -] - -# w: width scaling -# d: depth scaling -# r: image resolution -# Data is organised as (r, (w, d)) -const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6))) - -""" - EfficientNet(config::Symbol; pretrain::Bool = false) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). -See also [`efficientnet`](#). - -# Arguments - - - `config`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet -""" -struct EfficientNet - layers::Any -end -@functor EfficientNet - -function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) - _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS; - inchannels, nclasses) - if pretrain - loadpretrain!(model, string("efficientnet-", config)) - end - return model -end - -(m::EfficientNet)(x) = m.layers(x) - -backbone(m::EfficientNet) = m.layers[1] -classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl deleted file mode 100644 index afcba5dab..000000000 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ /dev/null @@ -1,124 +0,0 @@ -""" - efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, - width_mult::Real = 1.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - -Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - -# Arguments - - - `config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `t`: expansion factor of the block - + `c`: output channels of the block (will be scaled by width_mult) - + `n`: number of block repetitions - + `s`: kernel stride in the block except the first block of each stage - + `r`: reduction factor of the squeeze-excite layer - - - `max_width`: The maximum number of feature maps in any layer of the network - - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, - width_mult::Real = 1.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - # building first layer - inplanes = _round_channels(24 * width_mult, 8) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, - bias = false)) - # building inverted residual blocks - for (t, inplanes, outplanes, n, s, reduction) in config - explanes = _round_channels(inplanes * t, 8) - for i in 1:n - stride = i == 1 ? s : 1 - if isnothing(reduction) - push!(layers, - fused_mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) - else - inplanes = _round_channels(inplanes * width_mult, 8) - outplanes = _round_channels(outplanes * width_mult, 8) - push!(layers, - mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) - end - inplanes = outplanes - end - end - # building last layers - outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, 8) : - max_width - append!(layers, conv_norm((1, 1), inplanes, outplanes, swish; bias = false)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) -end - -# block configs for EfficientNetv2 -const EFFNETV2_CONFIGS = Dict(:small => [ - (1, 24, 24, 2, 1, nothing), - (4, 24, 48, 4, 2, nothing), - (4, 48, 64, 4, 2, nothing), - (4, 64, 128, 6, 2, 4), - (6, 128, 160, 9, 1, 4), - (6, 160, 256, 15, 2, 4)], - :medium => [ - (1, 24, 24, 3, 1, nothing), - (4, 24, 48, 5, 2, nothing), - (4, 48, 80, 5, 2, nothing), - (4, 80, 160, 7, 2, 4), - (6, 160, 176, 14, 1, 4), - (6, 176, 304, 18, 2, 4), - (6, 304, 512, 5, 1, 4)], - :large => [ - (1, 32, 32, 4, 1, nothing), - (4, 32, 64, 7, 2, nothing), - (4, 64, 96, 7, 2, nothing), - (4, 96, 192, 10, 2, 4), - (6, 192, 224, 19, 1, 4), - (6, 224, 384, 25, 2, 4), - (6, 384, 640, 7, 1, 4)], - :xlarge => [ - (1, 32, 32, 4, 1, nothing), - (4, 32, 64, 8, 2, nothing), - (4, 64, 96, 8, 2, nothing), - (4, 96, 192, 16, 2, 4), - (6, 192, 256, 24, 1, 4), - (6, 256, 512, 32, 2, 4), - (6, 512, 640, 8, 1, 4)]) - -""" - EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - -# Arguments - - - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `width_mult`: Controls the number of output feature maps in each block (with 1 - being the default in the paper) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -struct EfficientNetv2 - layers::Any -end -@functor EfficientNetv2 - -function EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnetv2(EFFNETV2_CONFIGS[config]; width_mult, inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("efficientnetv2")) - end - return EfficientNetv2(layers) -end - -(m::EfficientNetv2)(x) = m.layers(x) - -backbone(m::EfficientNetv2) = m.layers[1] -classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl new file mode 100644 index 000000000..1059cb538 --- /dev/null +++ b/src/convnets/efficientnets/core.jl @@ -0,0 +1,78 @@ +abstract type _MBConfig end + +struct MBConvConfig <: _MBConfig + kernel_size::Dims{2} + inplanes::Integer + outplanes::Integer + expansion::Number + stride::Integer + nrepeats::Integer +end +function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, + expansion::Number, stride::Integer, nrepeats::Integer, + width_mult::Number = 1, depth_mult::Number = 1) + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + nrepeats = ceil(Int, nrepeats * depth_mult) + return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, + stride, nrepeats) +end + +function efficientnetblock(m::MBConvConfig, norm_layer) + layers = [] + explanes = _round_channels(m.inplanes * m.expansion, 8) + push!(layers, + mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer, + stride = m.stride, reduction = 4)) + explanes = _round_channels(m.outplanes * m.expansion, 8) + append!(layers, + [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer, + stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)]) + return Chain(layers...) +end + +struct FusedMBConvConfig <: _MBConfig + kernel_size::Dims{2} + inplanes::Integer + outplanes::Integer + expansion::Number + stride::Integer + nrepeats::Integer +end +function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, + expansion::Number, stride::Integer, nrepeats::Integer) + return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, + stride, nrepeats) +end + +function efficientnetblock(m::FusedMBConvConfig, norm_layer) + layers = [] + explanes = _round_channels(m.inplanes * m.expansion, 8) + push!(layers, + fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; + norm_layer, stride = m.stride)) + explanes = _round_channels(m.outplanes * m.expansion, 8) + append!(layers, + [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; + norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)]) + return Chain(layers...) +end + +function efficientnet(block_configs::AbstractVector{<:_MBConfig}; + headplanes::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) + layers = [] + # stem of the model + append!(layers, + conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer, + stride = 2, pad = SamePad())) + # building inverted residual blocks + append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs]) + # building last layers + outplanes = block_configs[end].outplanes + headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes + append!(layers, + conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad())) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) +end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl new file mode 100644 index 000000000..0bb481dda --- /dev/null +++ b/src/convnets/efficientnets/efficientnet.jl @@ -0,0 +1,59 @@ +# block configs for EfficientNet +const EFFICIENTNET_BLOCK_CONFIGS = [ + # k, i, o, e, s, n + (3, 32, 16, 1, 1, 1), + (3, 16, 24, 6, 2, 2), + (5, 24, 40, 6, 2, 2), + (3, 40, 80, 6, 2, 3), + (5, 80, 112, 6, 1, 3), + (5, 112, 192, 6, 2, 4), + (3, 192, 320, 6, 1, 1), +] + +# Data is organised as (r, (w, d)) +# r: image resolution +# w: width scaling +# d: depth scaling +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) + +""" + EfficientNet(config::Symbol; pretrain::Bool = false) + +Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). +See also [`efficientnet`](#). + +# Arguments + + - `config`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet +""" +struct EfficientNet + layers::Any +end +@functor EfficientNet + +function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) + _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + cfg_fn = (args...) -> MBConvConfig(args..., EFFICIENTNET_GLOBAL_CONFIGS[config][2]...) + block_configs = [cfg_fn(args...) for args in EFFICIENTNET_BLOCK_CONFIGS] + layers = efficientnet(block_configs; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnet-", config)) + end + return EfficientNet(layers) +end + +(m::EfficientNet)(x) = m.layers(x) + +backbone(m::EfficientNet) = m.layers[1] +classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl new file mode 100644 index 000000000..d2d6a3222 --- /dev/null +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -0,0 +1,67 @@ +# block configs for EfficientNetv2 +const EFFNETV2_CONFIGS = Dict(:small => [ + FusedMBConvConfig(3, 24, 24, 1, 1, 2), + FusedMBConvConfig(3, 24, 48, 4, 2, 4), + FusedMBConvConfig(3, 48, 64, 4, 2, 4), + MBConvConfig(3, 64, 128, 4, 2, 6), + MBConvConfig(3, 128, 160, 6, 1, 9), + MBConvConfig(3, 160, 256, 6, 2, 15)], + :medium => [ + FusedMBConvConfig(3, 24, 24, 1, 1, 3), + FusedMBConvConfig(3, 24, 48, 4, 2, 5), + FusedMBConvConfig(3, 48, 80, 4, 2, 5), + MBConvConfig(3, 80, 160, 4, 2, 7), + MBConvConfig(3, 160, 176, 6, 1, 14), + MBConvConfig(3, 176, 304, 6, 2, 18), + MBConvConfig(3, 304, 512, 6, 1, 5)], + :large => [ + FusedMBConvConfig(3, 32, 32, 1, 1, 4), + FusedMBConvConfig(3, 32, 64, 4, 2, 7), + FusedMBConvConfig(3, 64, 96, 4, 2, 7), + MBConvConfig(3, 96, 192, 4, 2, 10), + MBConvConfig(3, 192, 224, 6, 1, 19), + MBConvConfig(3, 224, 384, 6, 2, 25), + MBConvConfig(3, 384, 640, 6, 1, 7)], + :xlarge => [ + FusedMBConvConfig(3, 32, 32, 1, 1, 4), + FusedMBConvConfig(3, 32, 64, 4, 2, 8), + FusedMBConvConfig(3, 64, 96, 4, 2, 8), + MBConvConfig(3, 96, 192, 4, 2, 16), + MBConvConfig(3, 192, 224, 6, 1, 24), + MBConvConfig(3, 384, 512, 6, 2, 32), + MBConvConfig(3, 512, 768, 6, 1, 8)]) + +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `width_mult`: Controls the number of output feature maps in each block (with 1 + being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" +struct EfficientNetv2 + layers::Any +end +@functor EfficientNetv2 + +function EfficientNetv2(config::Symbol; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) + layers = efficientnet(EFFNETV2_CONFIGS[config]; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnetv2")) + end + return EfficientNetv2(layers) +end + +(m::EfficientNetv2)(x) = m.layers(x) + +backbone(m::EfficientNetv2) = m.layers[1] +classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/inception/googlenet.jl b/src/convnets/inceptions/googlenet.jl similarity index 100% rename from src/convnets/inception/googlenet.jl rename to src/convnets/inceptions/googlenet.jl diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl similarity index 100% rename from src/convnets/inception/inceptionresnetv2.jl rename to src/convnets/inceptions/inceptionresnetv2.jl diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl similarity index 100% rename from src/convnets/inception/inceptionv3.jl rename to src/convnets/inceptions/inceptionv3.jl diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl similarity index 100% rename from src/convnets/inception/inceptionv4.jl rename to src/convnets/inceptions/inceptionv4.jl diff --git a/src/convnets/inception/xception.jl b/src/convnets/inceptions/xception.jl similarity index 92% rename from src/convnets/inception/xception.jl rename to src/convnets/inceptions/xception.jl index 14b5444d6..33222e7be 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -19,8 +19,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int stride::Integer = 1, start_with_relu::Bool = true, grow_at_start::Bool = true) if outchannels != inchannels || stride != 1 - skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride) else skip = [identity] end @@ -35,8 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - dwsep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, - use_norm = (false, false))) + dwsep_conv_bn((3, 3), inc, outc; pad = 1, use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -57,8 +55,8 @@ Creates an Xception model. - `nclasses`: the number of output classes. """ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_norm((3, 3), 32, 64; bias = false)..., + backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 64)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), xception_block(128, 256, 2; stride = 2), xception_block(256, 728, 2; stride = 2), diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl similarity index 95% rename from src/convnets/mobilenet/mobilenetv1.jl rename to src/convnets/mobilenets/mobilenetv1.jl index db9dedbdb..caa899a53 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -27,9 +27,8 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati for _ in 1:nrepeats layer = dw ? dwsep_conv_bn((3, 3), inchannels, outchannels, activation; - stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1, - bias = false) + stride, pad = 1) : + conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1) append!(layers, layer) inchannels = outchannels end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl similarity index 97% rename from src/convnets/mobilenet/mobilenetv2.jl rename to src/convnets/mobilenets/mobilenetv2.jl index 7bc87bcb9..232286309 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -31,7 +31,7 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = _round_channels(32 * width_mult, divisor) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes; bias = false, pad = 1, stride = 2)) + conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, divisor) @@ -44,7 +44,7 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)) + append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl similarity index 98% rename from src/convnets/mobilenet/mobilenetv3.jl rename to src/convnets/mobilenets/mobilenetv3.jl index 68fe2f03b..78c55e144 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -30,8 +30,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1, - bias = false)) + conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) explanes = 0 # building inverted residual blocks for (k, t, c, reduction, activation, stride) in configs @@ -46,7 +45,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, # building last layers headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : max_width - append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) + append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), Dropout(dropout_rate), diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 1e6bb9fee..35bb34fc4 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -29,9 +29,9 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, first_planes = planes ÷ reduction_factor outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, - stride, pad = 1, bias = false) + stride, pad = 1) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, - pad = 1, bias = false) + pad = 1) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -72,12 +72,10 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, - bias = false) + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, - stride, pad = 1, groups = cardinality, bias = false) - conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm, - bias = false) + stride, pad = 1, groups = cardinality) + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -87,7 +85,7 @@ end function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false) return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - pad = SamePad(), stride, bias = false)...) + pad = SamePad(), stride)...) end # Downsample layer using max pooling @@ -95,8 +93,7 @@ function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer norm_layer = BatchNorm, revnorm::Bool = false) pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - bias = false)...) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -178,9 +175,9 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; - norm_layer, revnorm, stride = 2, pad = 1, bias = false)..., + norm_layer, revnorm, stride = 2, pad = 1)..., conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; - norm_layer, pad = 1, bias = false)..., + norm_layer, pad = 1)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) @@ -189,8 +186,7 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - revnorm, - stride = 2, pad = 1, bias = false)...) : + revnorm, stride = 2, pad = 1)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool) end diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index b5dae6663..e308e1125 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -30,17 +30,17 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, outplanes = planes * 4 pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride, - pad = 1, groups = cardinality, bias = false)...) + pad = 1, groups = cardinality)...) for _ in 1:max(1, scale - 1)] reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) layers = [ conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm, bias = false)..., + norm_layer, revnorm)..., chunk$(; size = width, dims = 3), tuplify, reslayer, conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm, bias = false)..., + norm_layer, revnorm)..., attn_fn(outplanes), ] return Chain(filter(!=(identity), layers)...) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index de232d9a3..163c13b68 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -17,7 +17,7 @@ function vgg_block(ifilters::Integer, ofilters::Integer, depth::Integer, batchno layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_norm(k, ifilters, ofilters; pad = p, bias = false)) + append!(layers, conv_norm(k, ifilters, ofilters; pad = p)) else push!(layers, Conv(k, ifilters => ofilters, relu; pad = p)) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 087082c8b..78300b0d0 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,10 +1,11 @@ """ - conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, - use_norm::Bool = true, stride::Integer = 1, pad::Integer = 0, - dilation::Integer = 1, groups::Integer = 1, [bias, weight, init]) + conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + stride::Integer = 1, pad::Integer = 0, dilation::Integer = 1, + groups::Integer = 1, [bias, weight, init]) - conv_norm(kernel_size, inplanes => outplanes, activation = identity; + conv_norm(kernel_size::Dims{2}, inplanes => outplanes, activation = identity; kwargs...) Create a convolution + batch normalization pair with activation. @@ -25,11 +26,14 @@ Create a convolution + batch normalization pair with activation. - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `bias`: bias for the convolution kernel. This is set to `false` by default if + `use_norm = true`. + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, - preact::Bool = false, use_norm::Bool = true, kwargs...) +function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + bias = !use_norm, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -56,30 +60,30 @@ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activatio end end # layers - layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), + layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; - kwargs...) +function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, + activation = identity; kwargs...) inplanes, outplanes = ch return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end # conv + bn layer combination as used by the inception model family matching # the default values used in TensorFlow -function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) +function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu; + kwargs...) return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, - eps = 1.0f-3, bias = false, kwargs...) + eps = 1.0f-3, kwargs...) end """ - dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), - pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) + dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: @@ -102,31 +106,32 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `revnorm`: set to `true` to place the batch norm before the convolution - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution + - `bias`: a tuple of two booleans to specify whether to use bias for the first and second + convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and + `use_norm[1] == true`. - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, +function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, + outplanes::Integer, activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), kwargs...) + use_norm::NTuple{2, Bool} = (true, true), + bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, - revnorm, use_norm = use_norm[1], stride, + revnorm, use_norm = use_norm[1], stride, bias = bias[1], groups = inplanes, kwargs...), conv_norm((1, 1), inplanes, outplanes, activation; eps, - revnorm, use_norm = use_norm[2])) + revnorm, use_norm = use_norm[2], bias = bias[2])) end +# TODO add support for stochastic depth to mbconv and fused_mbconv """ - mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, + mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, reduction::Union{Nothing, Integer} = nothing) - mbconv(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion::Real, - reduction::Union{Nothing, Integer} = nothing) - Create a basic inverted residual block for MobileNet variants ([reference](https://arxiv.org/abs/1905.02244)). @@ -134,46 +139,43 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, - specify the keyword argument `expansion`, which calculates the number of feature - maps in the hidden layer from the number of input feature maps as: - `hidden_planes = inplanes * expansion` + - `explanes`: The number of feature maps in the hidden layer - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)). + in a squeeze and excite layer (see [`squeeze_excite`](#)) """ -function mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing, +function mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] # expand - if inplanes != hidden_planes + if inplanes != explanes append!(layers, - conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false, - norm_layer)) + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end + # depthwise + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) - append!(layers, - squeeze_excite(hidden_planes, inplanes ÷ reduction; activation, - gate_activation = hardσ)) + push!(layers, + squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, + gate_activation = hardσ)) end - # depthwise - append!(layers, - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; bias = false, - norm_layer, stride, pad = SamePad(), groups = hidden_planes)) # project - append!(layers, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)) + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : Chain(layers...) end -function fused_mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, norm_layer = BatchNorm) +function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] if explanes != inplanes diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 31c06c07a..b252584fe 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -116,6 +116,7 @@ function Base.show(io::IO, d::DropBlock) return print(io, ")") end +# TODO look into "row" mode for stochastic depth """ DropPath(p; [rng = rng_from_array(x)]) diff --git a/test/convnets.jl b/test/convnets.jl index e087ceb0e..95d9cfdf6 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -178,7 +178,7 @@ end end @testset "EfficientNetv2" begin - @testset for config in [:small, :medium, :large, :xlarge] + @testset for config in [:small, :medium, :large] # :xlarge] m = EfficientNetv2(config) @test size(m(x_224)) == (1000, 1) if (EfficientNetv2, config) in PRETRAINED_MODELS From 3e7e5f7fa47956fd70ae5d27f287c946f9c74029 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 12 Aug 2022 13:51:02 +0530 Subject: [PATCH 08/34] Refactor EfficientNets --- src/Metalhead.jl | 21 +-- src/convnets/convmixer.jl | 4 +- src/convnets/densenet.jl | 8 +- src/convnets/efficientnet/efficientnet.jl | 113 ---------------- src/convnets/efficientnet/efficientnetv2.jl | 124 ------------------ src/convnets/efficientnets/core.jl | 78 +++++++++++ src/convnets/efficientnets/efficientnet.jl | 59 +++++++++ src/convnets/efficientnets/efficientnetv2.jl | 67 ++++++++++ .../{inception => inceptions}/googlenet.jl | 0 .../inceptionresnetv2.jl | 0 .../{inception => inceptions}/inceptionv3.jl | 0 .../{inception => inceptions}/inceptionv4.jl | 0 .../{inception => inceptions}/xception.jl | 10 +- .../{mobilenet => mobilenets}/mobilenetv1.jl | 5 +- .../{mobilenet => mobilenets}/mobilenetv2.jl | 4 +- .../{mobilenet => mobilenets}/mobilenetv3.jl | 5 +- src/convnets/resnets/core.jl | 24 ++-- src/convnets/resnets/res2net.jl | 6 +- src/convnets/vgg.jl | 2 +- src/layers/conv.jl | 104 ++++++++------- src/layers/drop.jl | 1 + test/convnets.jl | 4 +- 22 files changed, 301 insertions(+), 338 deletions(-) delete mode 100644 src/convnets/efficientnet/efficientnet.jl delete mode 100644 src/convnets/efficientnet/efficientnetv2.jl create mode 100644 src/convnets/efficientnets/core.jl create mode 100644 src/convnets/efficientnets/efficientnet.jl create mode 100644 src/convnets/efficientnets/efficientnetv2.jl rename src/convnets/{inception => inceptions}/googlenet.jl (100%) rename src/convnets/{inception => inceptions}/inceptionresnetv2.jl (100%) rename src/convnets/{inception => inceptions}/inceptionv3.jl (100%) rename src/convnets/{inception => inceptions}/inceptionv4.jl (100%) rename src/convnets/{inception => inceptions}/xception.jl (92%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv1.jl (95%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv2.jl (97%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv3.jl (98%) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index cad13afc9..6b0179f45 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -28,18 +28,19 @@ include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") include("convnets/resnets/res2net.jl") ## Inceptions -include("convnets/inception/googlenet.jl") -include("convnets/inception/inceptionv3.jl") -include("convnets/inception/inceptionv4.jl") -include("convnets/inception/inceptionresnetv2.jl") -include("convnets/inception/xception.jl") +include("convnets/inceptions/googlenet.jl") +include("convnets/inceptions/inceptionv3.jl") +include("convnets/inceptions/inceptionv4.jl") +include("convnets/inceptions/inceptionresnetv2.jl") +include("convnets/inceptions/xception.jl") ## EfficientNets -include("convnets/efficientnet/efficientnet.jl") -include("convnets/efficientnet/efficientnetv2.jl") +include("convnets/efficientnets/core.jl") +include("convnets/efficientnets/efficientnet.jl") +include("convnets/efficientnets/efficientnetv2.jl") ## MobileNets -include("convnets/mobilenet/mobilenetv1.jl") -include("convnets/mobilenet/mobilenetv2.jl") -include("convnets/mobilenet/mobilenetv3.jl") +include("convnets/mobilenets/mobilenetv1.jl") +include("convnets/mobilenets/mobilenetv2.jl") +include("convnets/mobilenets/mobilenetv3.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index c7dd058ff..1ca8487a9 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -1,5 +1,5 @@ """ - convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), + convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -16,7 +16,7 @@ Creates a ConvMixer model. - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ -function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), +function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 75e1ffde1..ca81b78ea 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -12,10 +12,10 @@ Create a Densenet bottleneck layer """ function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) inner_channels = expansion * outplanes - return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, + return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, - bias = false, revnorm = true)...), + revnorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes::Integer, outplanes::Integer) - return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; revnorm = true)..., MeanPool((2, 2))) end @@ -72,7 +72,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: nclasses::Integer = 1000) layers = [] append!(layers, - conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false)) + conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl deleted file mode 100644 index 4f06d81b7..000000000 --- a/src/convnets/efficientnet/efficientnet.jl +++ /dev/null @@ -1,113 +0,0 @@ -""" - efficientnet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - -# Arguments - - - `scalings`: global width and depth scaling (given as a tuple) - - - `block_configs`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `n`: number of block repetitions (will be scaled by global depth scaling) - + `k`: kernel size - + `s`: kernel stride - + `e`: expansion ratio - + `i`: block input channels (will be scaled by global width scaling) - + `o`: block output channels (will be scaled by global width scaling) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -function efficientnet(scalings::NTuple{2, Real}, - block_configs::AbstractVector{NTuple{6, Int}}; - inchannels::Integer = 3, nclasses::Integer = 1000) - # building first layer - wscale, dscale = scalings - scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) - scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - outplanes = _round_channels(scalew(32), 8) - stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, - pad = SamePad()) - # building inverted residual blocks - blocks = [] - for (n, k, s, e, i, o) in block_configs - inchannels = _round_channels(scalew(i), 8) - explanes = _round_channels(inchannels * e, 8) - outplanes = _round_channels(scalew(o), 8) - repeats = scaled(n) - push!(blocks, - mbconv((k, k), inchannels, explanes, outplanes, swish; - stride = s, reduction = 4)) - for _ in 1:(repeats - 1) - push!(blocks, - mbconv((k, k), outplanes, explanes, outplanes, swish; - stride = 1, reduction = 4)) - end - end - # building last layers - headplanes = outplanes * 4 - append!(blocks, - conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) - return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) -end - -# block configs for EfficientNet -const EFFICIENTNET_BLOCK_CONFIGS = [ - # (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), - (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320), -] - -# w: width scaling -# d: depth scaling -# r: image resolution -# Data is organised as (r, (w, d)) -const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6))) - -""" - EfficientNet(config::Symbol; pretrain::Bool = false) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). -See also [`efficientnet`](#). - -# Arguments - - - `config`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet -""" -struct EfficientNet - layers::Any -end -@functor EfficientNet - -function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) - _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS; - inchannels, nclasses) - if pretrain - loadpretrain!(model, string("efficientnet-", config)) - end - return model -end - -(m::EfficientNet)(x) = m.layers(x) - -backbone(m::EfficientNet) = m.layers[1] -classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl deleted file mode 100644 index afcba5dab..000000000 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ /dev/null @@ -1,124 +0,0 @@ -""" - efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, - width_mult::Real = 1.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - -Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - -# Arguments - - - `config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `t`: expansion factor of the block - + `c`: output channels of the block (will be scaled by width_mult) - + `n`: number of block repetitions - + `s`: kernel stride in the block except the first block of each stage - + `r`: reduction factor of the squeeze-excite layer - - - `max_width`: The maximum number of feature maps in any layer of the network - - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, - width_mult::Real = 1.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - # building first layer - inplanes = _round_channels(24 * width_mult, 8) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, - bias = false)) - # building inverted residual blocks - for (t, inplanes, outplanes, n, s, reduction) in config - explanes = _round_channels(inplanes * t, 8) - for i in 1:n - stride = i == 1 ? s : 1 - if isnothing(reduction) - push!(layers, - fused_mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) - else - inplanes = _round_channels(inplanes * width_mult, 8) - outplanes = _round_channels(outplanes * width_mult, 8) - push!(layers, - mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) - end - inplanes = outplanes - end - end - # building last layers - outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, 8) : - max_width - append!(layers, conv_norm((1, 1), inplanes, outplanes, swish; bias = false)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) -end - -# block configs for EfficientNetv2 -const EFFNETV2_CONFIGS = Dict(:small => [ - (1, 24, 24, 2, 1, nothing), - (4, 24, 48, 4, 2, nothing), - (4, 48, 64, 4, 2, nothing), - (4, 64, 128, 6, 2, 4), - (6, 128, 160, 9, 1, 4), - (6, 160, 256, 15, 2, 4)], - :medium => [ - (1, 24, 24, 3, 1, nothing), - (4, 24, 48, 5, 2, nothing), - (4, 48, 80, 5, 2, nothing), - (4, 80, 160, 7, 2, 4), - (6, 160, 176, 14, 1, 4), - (6, 176, 304, 18, 2, 4), - (6, 304, 512, 5, 1, 4)], - :large => [ - (1, 32, 32, 4, 1, nothing), - (4, 32, 64, 7, 2, nothing), - (4, 64, 96, 7, 2, nothing), - (4, 96, 192, 10, 2, 4), - (6, 192, 224, 19, 1, 4), - (6, 224, 384, 25, 2, 4), - (6, 384, 640, 7, 1, 4)], - :xlarge => [ - (1, 32, 32, 4, 1, nothing), - (4, 32, 64, 8, 2, nothing), - (4, 64, 96, 8, 2, nothing), - (4, 96, 192, 16, 2, 4), - (6, 192, 256, 24, 1, 4), - (6, 256, 512, 32, 2, 4), - (6, 512, 640, 8, 1, 4)]) - -""" - EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - -# Arguments - - - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `width_mult`: Controls the number of output feature maps in each block (with 1 - being the default in the paper) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -struct EfficientNetv2 - layers::Any -end -@functor EfficientNetv2 - -function EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnetv2(EFFNETV2_CONFIGS[config]; width_mult, inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("efficientnetv2")) - end - return EfficientNetv2(layers) -end - -(m::EfficientNetv2)(x) = m.layers(x) - -backbone(m::EfficientNetv2) = m.layers[1] -classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl new file mode 100644 index 000000000..1059cb538 --- /dev/null +++ b/src/convnets/efficientnets/core.jl @@ -0,0 +1,78 @@ +abstract type _MBConfig end + +struct MBConvConfig <: _MBConfig + kernel_size::Dims{2} + inplanes::Integer + outplanes::Integer + expansion::Number + stride::Integer + nrepeats::Integer +end +function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, + expansion::Number, stride::Integer, nrepeats::Integer, + width_mult::Number = 1, depth_mult::Number = 1) + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + nrepeats = ceil(Int, nrepeats * depth_mult) + return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, + stride, nrepeats) +end + +function efficientnetblock(m::MBConvConfig, norm_layer) + layers = [] + explanes = _round_channels(m.inplanes * m.expansion, 8) + push!(layers, + mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer, + stride = m.stride, reduction = 4)) + explanes = _round_channels(m.outplanes * m.expansion, 8) + append!(layers, + [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer, + stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)]) + return Chain(layers...) +end + +struct FusedMBConvConfig <: _MBConfig + kernel_size::Dims{2} + inplanes::Integer + outplanes::Integer + expansion::Number + stride::Integer + nrepeats::Integer +end +function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, + expansion::Number, stride::Integer, nrepeats::Integer) + return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, + stride, nrepeats) +end + +function efficientnetblock(m::FusedMBConvConfig, norm_layer) + layers = [] + explanes = _round_channels(m.inplanes * m.expansion, 8) + push!(layers, + fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; + norm_layer, stride = m.stride)) + explanes = _round_channels(m.outplanes * m.expansion, 8) + append!(layers, + [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; + norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)]) + return Chain(layers...) +end + +function efficientnet(block_configs::AbstractVector{<:_MBConfig}; + headplanes::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) + layers = [] + # stem of the model + append!(layers, + conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer, + stride = 2, pad = SamePad())) + # building inverted residual blocks + append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs]) + # building last layers + outplanes = block_configs[end].outplanes + headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes + append!(layers, + conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad())) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) +end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl new file mode 100644 index 000000000..0bb481dda --- /dev/null +++ b/src/convnets/efficientnets/efficientnet.jl @@ -0,0 +1,59 @@ +# block configs for EfficientNet +const EFFICIENTNET_BLOCK_CONFIGS = [ + # k, i, o, e, s, n + (3, 32, 16, 1, 1, 1), + (3, 16, 24, 6, 2, 2), + (5, 24, 40, 6, 2, 2), + (3, 40, 80, 6, 2, 3), + (5, 80, 112, 6, 1, 3), + (5, 112, 192, 6, 2, 4), + (3, 192, 320, 6, 1, 1), +] + +# Data is organised as (r, (w, d)) +# r: image resolution +# w: width scaling +# d: depth scaling +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) + +""" + EfficientNet(config::Symbol; pretrain::Bool = false) + +Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). +See also [`efficientnet`](#). + +# Arguments + + - `config`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet +""" +struct EfficientNet + layers::Any +end +@functor EfficientNet + +function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) + _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + cfg_fn = (args...) -> MBConvConfig(args..., EFFICIENTNET_GLOBAL_CONFIGS[config][2]...) + block_configs = [cfg_fn(args...) for args in EFFICIENTNET_BLOCK_CONFIGS] + layers = efficientnet(block_configs; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnet-", config)) + end + return EfficientNet(layers) +end + +(m::EfficientNet)(x) = m.layers(x) + +backbone(m::EfficientNet) = m.layers[1] +classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl new file mode 100644 index 000000000..d2d6a3222 --- /dev/null +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -0,0 +1,67 @@ +# block configs for EfficientNetv2 +const EFFNETV2_CONFIGS = Dict(:small => [ + FusedMBConvConfig(3, 24, 24, 1, 1, 2), + FusedMBConvConfig(3, 24, 48, 4, 2, 4), + FusedMBConvConfig(3, 48, 64, 4, 2, 4), + MBConvConfig(3, 64, 128, 4, 2, 6), + MBConvConfig(3, 128, 160, 6, 1, 9), + MBConvConfig(3, 160, 256, 6, 2, 15)], + :medium => [ + FusedMBConvConfig(3, 24, 24, 1, 1, 3), + FusedMBConvConfig(3, 24, 48, 4, 2, 5), + FusedMBConvConfig(3, 48, 80, 4, 2, 5), + MBConvConfig(3, 80, 160, 4, 2, 7), + MBConvConfig(3, 160, 176, 6, 1, 14), + MBConvConfig(3, 176, 304, 6, 2, 18), + MBConvConfig(3, 304, 512, 6, 1, 5)], + :large => [ + FusedMBConvConfig(3, 32, 32, 1, 1, 4), + FusedMBConvConfig(3, 32, 64, 4, 2, 7), + FusedMBConvConfig(3, 64, 96, 4, 2, 7), + MBConvConfig(3, 96, 192, 4, 2, 10), + MBConvConfig(3, 192, 224, 6, 1, 19), + MBConvConfig(3, 224, 384, 6, 2, 25), + MBConvConfig(3, 384, 640, 6, 1, 7)], + :xlarge => [ + FusedMBConvConfig(3, 32, 32, 1, 1, 4), + FusedMBConvConfig(3, 32, 64, 4, 2, 8), + FusedMBConvConfig(3, 64, 96, 4, 2, 8), + MBConvConfig(3, 96, 192, 4, 2, 16), + MBConvConfig(3, 192, 224, 6, 1, 24), + MBConvConfig(3, 384, 512, 6, 2, 32), + MBConvConfig(3, 512, 768, 6, 1, 8)]) + +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `width_mult`: Controls the number of output feature maps in each block (with 1 + being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" +struct EfficientNetv2 + layers::Any +end +@functor EfficientNetv2 + +function EfficientNetv2(config::Symbol; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) + layers = efficientnet(EFFNETV2_CONFIGS[config]; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnetv2")) + end + return EfficientNetv2(layers) +end + +(m::EfficientNetv2)(x) = m.layers(x) + +backbone(m::EfficientNetv2) = m.layers[1] +classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/inception/googlenet.jl b/src/convnets/inceptions/googlenet.jl similarity index 100% rename from src/convnets/inception/googlenet.jl rename to src/convnets/inceptions/googlenet.jl diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl similarity index 100% rename from src/convnets/inception/inceptionresnetv2.jl rename to src/convnets/inceptions/inceptionresnetv2.jl diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl similarity index 100% rename from src/convnets/inception/inceptionv3.jl rename to src/convnets/inceptions/inceptionv3.jl diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl similarity index 100% rename from src/convnets/inception/inceptionv4.jl rename to src/convnets/inceptions/inceptionv4.jl diff --git a/src/convnets/inception/xception.jl b/src/convnets/inceptions/xception.jl similarity index 92% rename from src/convnets/inception/xception.jl rename to src/convnets/inceptions/xception.jl index 14b5444d6..33222e7be 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -19,8 +19,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int stride::Integer = 1, start_with_relu::Bool = true, grow_at_start::Bool = true) if outchannels != inchannels || stride != 1 - skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride) else skip = [identity] end @@ -35,8 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - dwsep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, - use_norm = (false, false))) + dwsep_conv_bn((3, 3), inc, outc; pad = 1, use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -57,8 +55,8 @@ Creates an Xception model. - `nclasses`: the number of output classes. """ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_norm((3, 3), 32, 64; bias = false)..., + backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 64)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), xception_block(128, 256, 2; stride = 2), xception_block(256, 728, 2; stride = 2), diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl similarity index 95% rename from src/convnets/mobilenet/mobilenetv1.jl rename to src/convnets/mobilenets/mobilenetv1.jl index db9dedbdb..caa899a53 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -27,9 +27,8 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati for _ in 1:nrepeats layer = dw ? dwsep_conv_bn((3, 3), inchannels, outchannels, activation; - stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1, - bias = false) + stride, pad = 1) : + conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1) append!(layers, layer) inchannels = outchannels end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl similarity index 97% rename from src/convnets/mobilenet/mobilenetv2.jl rename to src/convnets/mobilenets/mobilenetv2.jl index 7bc87bcb9..232286309 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -31,7 +31,7 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = _round_channels(32 * width_mult, divisor) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes; bias = false, pad = 1, stride = 2)) + conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, divisor) @@ -44,7 +44,7 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)) + append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl similarity index 98% rename from src/convnets/mobilenet/mobilenetv3.jl rename to src/convnets/mobilenets/mobilenetv3.jl index 68fe2f03b..78c55e144 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -30,8 +30,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1, - bias = false)) + conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) explanes = 0 # building inverted residual blocks for (k, t, c, reduction, activation, stride) in configs @@ -46,7 +45,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, # building last layers headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : max_width - append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) + append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), Dropout(dropout_rate), diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 1e6bb9fee..35bb34fc4 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -29,9 +29,9 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, first_planes = planes ÷ reduction_factor outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, - stride, pad = 1, bias = false) + stride, pad = 1) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, - pad = 1, bias = false) + pad = 1) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -72,12 +72,10 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, - bias = false) + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, - stride, pad = 1, groups = cardinality, bias = false) - conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm, - bias = false) + stride, pad = 1, groups = cardinality) + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -87,7 +85,7 @@ end function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false) return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - pad = SamePad(), stride, bias = false)...) + pad = SamePad(), stride)...) end # Downsample layer using max pooling @@ -95,8 +93,7 @@ function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer norm_layer = BatchNorm, revnorm::Bool = false) pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - bias = false)...) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -178,9 +175,9 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; - norm_layer, revnorm, stride = 2, pad = 1, bias = false)..., + norm_layer, revnorm, stride = 2, pad = 1)..., conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; - norm_layer, pad = 1, bias = false)..., + norm_layer, pad = 1)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) @@ -189,8 +186,7 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - revnorm, - stride = 2, pad = 1, bias = false)...) : + revnorm, stride = 2, pad = 1)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool) end diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index b5dae6663..e308e1125 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -30,17 +30,17 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, outplanes = planes * 4 pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride, - pad = 1, groups = cardinality, bias = false)...) + pad = 1, groups = cardinality)...) for _ in 1:max(1, scale - 1)] reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) layers = [ conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm, bias = false)..., + norm_layer, revnorm)..., chunk$(; size = width, dims = 3), tuplify, reslayer, conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm, bias = false)..., + norm_layer, revnorm)..., attn_fn(outplanes), ] return Chain(filter(!=(identity), layers)...) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index de232d9a3..163c13b68 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -17,7 +17,7 @@ function vgg_block(ifilters::Integer, ofilters::Integer, depth::Integer, batchno layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_norm(k, ifilters, ofilters; pad = p, bias = false)) + append!(layers, conv_norm(k, ifilters, ofilters; pad = p)) else push!(layers, Conv(k, ifilters => ofilters, relu; pad = p)) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 087082c8b..78300b0d0 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,10 +1,11 @@ """ - conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, - use_norm::Bool = true, stride::Integer = 1, pad::Integer = 0, - dilation::Integer = 1, groups::Integer = 1, [bias, weight, init]) + conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + stride::Integer = 1, pad::Integer = 0, dilation::Integer = 1, + groups::Integer = 1, [bias, weight, init]) - conv_norm(kernel_size, inplanes => outplanes, activation = identity; + conv_norm(kernel_size::Dims{2}, inplanes => outplanes, activation = identity; kwargs...) Create a convolution + batch normalization pair with activation. @@ -25,11 +26,14 @@ Create a convolution + batch normalization pair with activation. - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `bias`: bias for the convolution kernel. This is set to `false` by default if + `use_norm = true`. + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, - preact::Bool = false, use_norm::Bool = true, kwargs...) +function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + bias = !use_norm, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -56,30 +60,30 @@ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activatio end end # layers - layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), + layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; - kwargs...) +function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, + activation = identity; kwargs...) inplanes, outplanes = ch return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end # conv + bn layer combination as used by the inception model family matching # the default values used in TensorFlow -function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) +function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu; + kwargs...) return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, - eps = 1.0f-3, bias = false, kwargs...) + eps = 1.0f-3, kwargs...) end """ - dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), - pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) + dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: @@ -102,31 +106,32 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `revnorm`: set to `true` to place the batch norm before the convolution - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution + - `bias`: a tuple of two booleans to specify whether to use bias for the first and second + convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and + `use_norm[1] == true`. - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, +function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, + outplanes::Integer, activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), kwargs...) + use_norm::NTuple{2, Bool} = (true, true), + bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, - revnorm, use_norm = use_norm[1], stride, + revnorm, use_norm = use_norm[1], stride, bias = bias[1], groups = inplanes, kwargs...), conv_norm((1, 1), inplanes, outplanes, activation; eps, - revnorm, use_norm = use_norm[2])) + revnorm, use_norm = use_norm[2], bias = bias[2])) end +# TODO add support for stochastic depth to mbconv and fused_mbconv """ - mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, + mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, reduction::Union{Nothing, Integer} = nothing) - mbconv(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion::Real, - reduction::Union{Nothing, Integer} = nothing) - Create a basic inverted residual block for MobileNet variants ([reference](https://arxiv.org/abs/1905.02244)). @@ -134,46 +139,43 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, - specify the keyword argument `expansion`, which calculates the number of feature - maps in the hidden layer from the number of input feature maps as: - `hidden_planes = inplanes * expansion` + - `explanes`: The number of feature maps in the hidden layer - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)). + in a squeeze and excite layer (see [`squeeze_excite`](#)) """ -function mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing, +function mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] # expand - if inplanes != hidden_planes + if inplanes != explanes append!(layers, - conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false, - norm_layer)) + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end + # depthwise + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) - append!(layers, - squeeze_excite(hidden_planes, inplanes ÷ reduction; activation, - gate_activation = hardσ)) + push!(layers, + squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, + gate_activation = hardσ)) end - # depthwise - append!(layers, - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; bias = false, - norm_layer, stride, pad = SamePad(), groups = hidden_planes)) # project - append!(layers, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)) + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : Chain(layers...) end -function fused_mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, norm_layer = BatchNorm) +function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] if explanes != inplanes diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 31c06c07a..b252584fe 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -116,6 +116,7 @@ function Base.show(io::IO, d::DropBlock) return print(io, ")") end +# TODO look into "row" mode for stochastic depth """ DropPath(p; [rng = rng_from_array(x)]) diff --git a/test/convnets.jl b/test/convnets.jl index e087ceb0e..34bbb5121 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -161,7 +161,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8] + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] x = rand(Float32, r, r, 3, 1) @@ -178,7 +178,7 @@ end end @testset "EfficientNetv2" begin - @testset for config in [:small, :medium, :large, :xlarge] + @testset for config in [:small, :medium, :large] # :xlarge] m = EfficientNetv2(config) @test size(m(x_224)) == (1000, 1) if (EfficientNetv2, config) in PRETRAINED_MODELS From 20a4b3701b6e531c86ccee708d24671adb91498d Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 13 Aug 2022 07:56:59 +0530 Subject: [PATCH 09/34] Fixes --- src/convnets/efficientnets/efficientnetv2.jl | 3 ++- src/layers/conv.jl | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index d2d6a3222..6847e7fa8 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -54,7 +54,8 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnet(EFFNETV2_CONFIGS[config]; inchannels, nclasses) + layers = efficientnet(EFFNETV2_CONFIGS[config]; headplanes = 1280, inchannels, + nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 78300b0d0..c94ceb045 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -181,7 +181,8 @@ function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, if explanes != inplanes # fused expand append!(layers, - conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride)) + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) else From ea13dd57937daef23fdeb3760026155e7795cc6d Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 14 Aug 2022 19:24:52 +0530 Subject: [PATCH 10/34] Some refactors, some consistency, some features --- .github/workflows/CI.yml | 3 +- src/convnets/convmixer.jl | 15 ++-- src/convnets/densenet.jl | 15 ++-- src/convnets/efficientnets/core.jl | 10 +-- src/convnets/inceptions/inceptionresnetv2.jl | 6 +- src/convnets/inceptions/inceptionv4.jl | 6 +- src/convnets/inceptions/xception.jl | 4 +- src/convnets/mobilenets/mobilenetv1.jl | 9 +- src/convnets/mobilenets/mobilenetv2.jl | 9 +- src/convnets/mobilenets/mobilenetv3.jl | 37 +++++--- src/convnets/resnets/core.jl | 19 ++-- src/layers/Layers.jl | 5 +- src/layers/classifier.jl | 93 ++++++++++++++++++++ src/layers/conv.jl | 9 +- src/layers/drop.jl | 44 +++++---- src/layers/mlp.jl | 44 +-------- src/utilities.jl | 8 +- 17 files changed, 215 insertions(+), 121 deletions(-) create mode 100644 src/layers/classifier.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 37cda3263..5304bc317 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,8 +28,7 @@ jobs: suite: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - - '"EfficientNet"' - - '"EfficientNetv2"' + - '"EfficientNet"' - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' - '[r"Res2Net", r"Res2NeXt"]' diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 1ca8487a9..bc1a71a5f 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -17,16 +17,21 @@ Creates a ConvMixer model. - `nclasses`: number of classes in the output """ function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), - patch_size::Dims{2} = (7, 7), activation = gelu, + patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) - stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, - stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; + layers = [] + # stem of the model + append!(layers, + conv_norm(patch_size, inchannels, planes, activation; preact = true, + stride = patch_size[1])) + # stages of the model + stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; preact = true, groups = planes, pad = SamePad())), +), conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] - return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses)) + append!(layers, stages) + return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate)) end const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index ca81b78ea..a7c367c1c 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates) end """ - densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000) + densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a DenseNet model ([reference](https://arxiv.org/abs/1608.06993)). @@ -66,10 +67,11 @@ Create a DenseNet model - `growth_rates`: the growth rates of output feature maps within each [`dense_block`](#) (a vector of vectors) - `reduction`: the factor by which the number of feature maps is scaled across each transition + - `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes """ -function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3, - nclasses::Integer = 1000) +function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] append!(layers, conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) @@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: inplanes = floor(Int, outplanes * reduction) end push!(layers, BatchNorm(outplanes, relu)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end """ @@ -100,9 +102,10 @@ Create a DenseNet model - `nclasses`: the number of output classes """ function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, - reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) + reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; - reduction, inchannels, nclasses) + reduction, dropout_rate, inchannels, nclasses) end const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 1059cb538..7a221c0e4 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -4,13 +4,13 @@ struct MBConvConfig <: _MBConfig kernel_size::Dims{2} inplanes::Integer outplanes::Integer - expansion::Number + expansion::Real stride::Integer nrepeats::Integer end function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Number, stride::Integer, nrepeats::Integer, - width_mult::Number = 1, depth_mult::Number = 1) + expansion::Real, stride::Integer, nrepeats::Integer, + width_mult::Real = 1, depth_mult::Real = 1) inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) nrepeats = ceil(Int, nrepeats * depth_mult) @@ -35,12 +35,12 @@ struct FusedMBConvConfig <: _MBConfig kernel_size::Dims{2} inplanes::Integer outplanes::Integer - expansion::Number + expansion::Real stride::Integer nrepeats::Integer end function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Number, stride::Integer, nrepeats::Integer) + expansion::Real, stride::Integer, nrepeats::Integer) return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, stride, nrepeats) end diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index 7f462c0cf..bd88648e9 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -64,7 +64,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -72,10 +72,10 @@ Creates an InceptionResNetv2 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, +function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index b43f6bc1d..13d40da25 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -85,7 +85,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -93,10 +93,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, +function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 33222e7be..171bddd19 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -43,14 +43,14 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end """ - xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) + xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - `nclasses`: the number of output classes. """ diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index caa899a53..542edec81 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -1,5 +1,6 @@ """ - mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, + mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). @@ -16,10 +17,12 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + `s`: The stride of the convolutional kernel + `r`: The number of time this configuration block is repeated - `activate`: The activation function to use throughout the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. - `inchannels`: The number of input channels. The default value is 3. - `nclasses`: The number of output classes """ -function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, +function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] for (dw, outchannels, stride, nrepeats) in config @@ -33,7 +36,7 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati inchannels = outchannels end end - return Chain(Chain(layers...), create_classifier(inchannels, nclasses)) + return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 232286309..d81256968 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -20,7 +20,7 @@ Create a MobileNetv2 model. (with 1 being the default in the paper) - `max_width`: The maximum number of feature maps in any layer of the network - `divisor`: The divisor used to round the number of feature maps in each block - - `dropout_rate`: rate of dropout in the classifier head + - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ @@ -33,12 +33,13 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - for (t, c, n, s, a) in configs + for (t, c, n, s, activation) in configs outplanes = _round_channels(c * width_mult, divisor) for i in 1:n + stride = i == 1 ? s : 1 push!(layers, - mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a; - stride = i == 1 ? s : 1)) + mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, + activation; stride)) inplanes = outplanes end end diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 78c55e144..82265e125 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -1,7 +1,7 @@ """ mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv3 model. ([reference](https://arxiv.org/abs/1905.02244)). @@ -19,12 +19,14 @@ Create a MobileNetv3 model. - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) - - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. + - `inchannels`: The number of input channels. - `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, dropout_rate = 0.2, + max_width::Integer = 1024, reduced_tail::Bool = false, + tail_dilated::Bool = false, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) @@ -32,25 +34,34 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) explanes = 0 + nstages = length(configs) + reduced_divider = 1 # building inverted residual blocks - for (k, t, c, reduction, activation, stride) in configs + for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs) + dilation = 1 + if nstages - i <= 2 + if reduced_tail + reduced_divider = 2 + c /= reduced_divider + end + if tail_dilated + dilation = 2 + end + end # inverted residual layers outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, mbconv((k, k), inplanes, explanes, outplanes, activation; - stride, reduction)) + stride, reduction, dilation)) inplanes = outplanes end # building last layers - headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : - max_width + headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8) append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) - classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(explanes, headplanes, hardswish), - Dropout(dropout_rate), - Dense(headplanes, nclasses)) - return Chain(Chain(layers...), classifier) + return Chain(Chain(layers...), + create_classifier(explanes, headplanes, nclasses, + (hardswish, identity); dropout_rate)) end # Layer configurations for small and large models for MobileNetv3 diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 35bb34fc4..458481d73 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -27,12 +27,11 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor - outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, stride, pad = 1) - conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, + conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm, pad = 1) - layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), + layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes), drop_path] return Chain(filter!(!=(identity), layers)...) end @@ -201,7 +200,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; expansion::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) @@ -236,7 +235,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; expansion::Integer = 4, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) @@ -295,8 +294,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...) + use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, + dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) # Block builder @@ -319,8 +318,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0" - @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to 0.0" + @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" + @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, @@ -347,7 +346,7 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) - +# larger ResNet-like models const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 72ace2c2c..45615df5e 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -28,7 +28,10 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mlp.jl") -export mlp_block, gated_mlp_block, create_fc, create_classifier +export mlp_block, gated_mlp_block + +include("classifier.jl") +export create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl new file mode 100644 index 000000000..bebdc4099 --- /dev/null +++ b/src/layers/classifier.jl @@ -0,0 +1,93 @@ +""" + create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + +Creates a classifier head to be used for models. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `activation`: activation function to use + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv && pool_layer !== identity + if use_conv + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # Dropout is applied after the pooling layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # Fully-connected layer + if use_conv + push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) + else + push!(classifier, Dense(inplanes => nclasses, activation)) + end + return Chain(classifier...) +end + +""" + create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + +Creates a classifier head to be used for models with an extra hidden layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `hidden_planes`: number of hidden feature maps + - `nclasses`: number of output classes + - `activations`: activation functions to use for the hidden and output layers. This is a + tuple of two elements, the first being the activation function for the hidden layer and the + second for the output layer. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. This + is a tuple of two booleans, the first for the hidden layer and the second for the output + layer. + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2, Any} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + fc_layers = [uc ? Conv$(1, 1) : Dense for uc in use_conv] + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv[1] && pool_layer !== identity + if use_conv[1] + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv[1]` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # first fully-connected layer + if !isnothing(hidden_planes) + push!(classifier, fc_layers[1](inplanes => hidden_planes, activations[1])) + end + # Dropout is applied after the first dense layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # second fully-connected layer + push!(classifier, fc_layers[2](hidden_planes => nclasses, activations[2])) + return Chain(classifier...) +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c94ceb045..bb39a0e07 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -146,9 +146,9 @@ Create a basic inverted residual block for MobileNet variants - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)) """ -function mbconv(kernel_size::Dims{2}, inplanes::Integer, - explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, reduction::Union{Nothing, Integer} = nothing, +function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] @@ -158,9 +158,10 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end # depthwise + stride = dilation > 1 ? 1 : stride append!(layers, conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, pad = SamePad(), groups = explanes)) + stride, dilation, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) push!(layers, diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b252584fe..387b562ef 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -20,7 +20,8 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. - `x`: input array - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -56,11 +57,25 @@ dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) The `DropBlock` layer. While training, it zeroes out continguous regions of size `block_size` in the input. During inference, it simply returns the input `x`. +It can be used in two ways: either with all blocks having the same survival probability +or with a linear scaling rule across the blocks. This is performed only at training time. +At test time, the `DropBlock` layer is equivalent to `identity`. + +!!! warning + + In the case of the linear scaling rule, the calculations of survival probabilities for each + block may lead to a survival probability > 1 for a given block. This will lead to + `DropBlock` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. + ((reference)[https://arxiv.org/abs/1810.12890]) # Arguments - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -90,11 +105,8 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_s function (m::DropBlock)(x) _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) - if Flux._isactive(m) - return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) - else - return x - end + return Flux._isactive(m) ? + dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) : x end function Flux.testmode!(m::DropBlock, mode = true) @@ -103,7 +115,7 @@ end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, rng = rng_from_array()) - if drop_block_prob == 0.0 + if isnothing(drop_block_prob) return identity end return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) @@ -120,8 +132,8 @@ end """ DropPath(p; [rng = rng_from_array(x)]) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and -`identity` otherwise. +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 ≤ p ≤ 1` and +`identity` if p is `nothing`. ([reference](https://arxiv.org/abs/1603.09382)) This layer can be used to drop certain blocks in a residual structure and allow them to @@ -134,10 +146,10 @@ equivalent to `identity`. In the case of the linear scaling rule, the calculations of survival probabilities for each block may lead to a survival probability > 1 for a given block. This will lead to - `DropPath` returning `identity`, which may not be desirable. This usually happens with - a low number of blocks and a high base survival probability, so it is recommended to - use a fixed base survival probability across blocks. If this is not possible, then - a lower base survival probability is recommended. + `DropPath` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. # Arguments @@ -146,4 +158,6 @@ equivalent to `identity`. for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU. """ -DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity +function DropPath(p; rng = rng_from_array()) + return isnothing(p) ? identity : Dropout(p; dims = 4, rng) +end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 467df30a4..e6336de9c 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,3 +1,4 @@ +# TODO @theabhirath figure out consistent behaviour for dropout rates - 0.0 vs `nothing` """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; dropout_rate = 0., activation = gelu) @@ -45,46 +46,3 @@ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) - -""" - create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = 0.0, use_conv::Bool = false) - -Creates a classifier head to be used for models. - -# Arguments - - - `inplanes`: number of input feature maps - - `nclasses`: number of output classes - - `activation`: activation function to use - - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with - any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. - - `dropout_rate`: dropout rate used in the classifier head. - - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. -""" -function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = nothing) - # Decide whether to flatten the input or not - flatten_in_pool = !use_conv && pool_layer !== identity - if use_conv - @assert pool_layer === identity - "`pool_layer` must be identity if `use_conv` is true" - end - classifier = [] - if flatten_in_pool - push!(classifier, pool_layer, MLUtils.flatten) - else - push!(classifier, pool_layer) - end - # Dropout is applied after the pooling layer - isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) - # Fully-connected layer - if use_conv - push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) - else - push!(classifier, Dense(inplanes => nclasses, activation)) - end - return Chain(classifier...) -end diff --git a/src/utilities.jl b/src/utilities.jl index 4a611b5a2..09074b0e8 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -66,13 +66,17 @@ function _maybe_big_show(io, model) end """ - linear_scheduler(drop_path_rate = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_rate = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_rate::Nothing; depth::Integer) -Returns the dropout rates for a given depth using the linear scaling rule. +Returns the dropout rates for a given depth using the linear scaling rule. If the +`drop_rate` is `nothing`, it returns a `Vector` of length `depth` with all values +equal to `nothing`. """ function linear_scheduler(drop_rate = 0.0; depth::Integer, start_value = 0.0) return LinRange(start_value, drop_rate, depth) end +linear_scheduler(drop_rate::Nothing; depth::Integer) = fill(drop_rate, depth) # Utility function for depth and configuration checks in models function _checkconfig(config, configs) From c99895131bcedef5c72e44c2f6f6a61590b28f0a Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Mon, 15 Aug 2022 17:11:33 +0530 Subject: [PATCH 11/34] The real hero was `block_idx` all along --- src/convnets/efficientnets/core.jl | 101 ++++++++----------- src/convnets/efficientnets/efficientnet.jl | 8 +- src/convnets/efficientnets/efficientnetv2.jl | 60 +++++------ src/convnets/resnets/core.jl | 56 +++++----- src/convnets/resnets/res2net.jl | 1 - src/layers/Layers.jl | 5 +- src/layers/conv.jl | 68 ------------- src/layers/mbconv.jl | 65 ++++++++++++ 8 files changed, 176 insertions(+), 188 deletions(-) create mode 100644 src/layers/mbconv.jl diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 7a221c0e4..139252a3b 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,78 +1,61 @@ -abstract type _MBConfig end - -struct MBConvConfig <: _MBConfig - kernel_size::Dims{2} - inplanes::Integer - outplanes::Integer - expansion::Real - stride::Integer - nrepeats::Integer -end -function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Real, stride::Integer, nrepeats::Integer, - width_mult::Real = 1, depth_mult::Real = 1) +function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, + stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) + depth_mult, width_mult = scalings + k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) - nrepeats = ceil(Int, nrepeats * depth_mult) - return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, - stride, nrepeats) + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = mbconv((k, k), inplanes, explanes, outplanes, swish; norm_layer, + stride, reduction = 4) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, ceil(Int, nrepeats * depth_mult) end -function efficientnetblock(m::MBConvConfig, norm_layer) - layers = [] - explanes = _round_channels(m.inplanes * m.expansion, 8) - push!(layers, - mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer, - stride = m.stride, reduction = 4)) - explanes = _round_channels(m.outplanes * m.expansion, 8) - append!(layers, - [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer, - stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)]) - return Chain(layers...) +function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, + stage_idx::Integer; norm_layer = BatchNorm) + k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = fused_mbconv((k, k), inplanes, explanes, outplanes, swish; + norm_layer, stride) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats end -struct FusedMBConvConfig <: _MBConfig - kernel_size::Dims{2} - inplanes::Integer - outplanes::Integer - expansion::Real - stride::Integer - nrepeats::Integer -end -function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Real, stride::Integer, nrepeats::Integer) - return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, - stride, nrepeats) -end - -function efficientnetblock(m::FusedMBConvConfig, norm_layer) - layers = [] - explanes = _round_channels(m.inplanes * m.expansion, 8) - push!(layers, - fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; - norm_layer, stride = m.stride)) - explanes = _round_channels(m.outplanes * m.expansion, 8) - append!(layers, - [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; - norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)]) - return Chain(layers...) +function efficientnet_builder(block_configs::AbstractVector{NTuple{6, Int}}, + residual_fns::AbstractVector; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) + bxs = [residual_fn(block_configs, stage_idx; scalings, norm_layer) + for (stage_idx, residual_fn) in enumerate(residual_fns)] + return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end -function efficientnet(block_configs::AbstractVector{<:_MBConfig}; - headplanes::Union{Nothing, Integer} = nothing, +function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, + residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), + headplanes::Integer = _round_channels(block_configs[end][3] * + scalings[2], 8) * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer, + conv_norm((3, 3), inchannels, block_configs[1][2], swish; norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks - append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs]) + get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; + scalings, norm_layer) + append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers - outplanes = block_configs[end].outplanes - headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes append!(layers, - conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad())) + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[2], 8), + headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 0bb481dda..bff9d8dde 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -9,7 +9,6 @@ const EFFICIENTNET_BLOCK_CONFIGS = [ (5, 112, 192, 6, 2, 4), (3, 192, 320, 6, 1, 1), ] - # Data is organised as (r, (w, d)) # r: image resolution # w: width scaling @@ -44,9 +43,10 @@ end function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - cfg_fn = (args...) -> MBConvConfig(args..., EFFICIENTNET_GLOBAL_CONFIGS[config][2]...) - block_configs = [cfg_fn(args...) for args in EFFICIENTNET_BLOCK_CONFIGS] - layers = efficientnet(block_configs; inchannels, nclasses) + scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] + layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS, + fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS)); + scalings, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index 6847e7fa8..ecbeed07a 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,35 +1,36 @@ # block configs for EfficientNetv2 +# data organised as (k, i, o, e, s, n) const EFFNETV2_CONFIGS = Dict(:small => [ - FusedMBConvConfig(3, 24, 24, 1, 1, 2), - FusedMBConvConfig(3, 24, 48, 4, 2, 4), - FusedMBConvConfig(3, 48, 64, 4, 2, 4), - MBConvConfig(3, 64, 128, 4, 2, 6), - MBConvConfig(3, 128, 160, 6, 1, 9), - MBConvConfig(3, 160, 256, 6, 2, 15)], + (3, 24, 24, 1, 1, 2), + (3, 24, 48, 4, 2, 4), + (3, 48, 64, 4, 2, 4), + (3, 64, 128, 4, 2, 6), + (3, 128, 160, 6, 1, 9), + (3, 160, 256, 6, 2, 15)], :medium => [ - FusedMBConvConfig(3, 24, 24, 1, 1, 3), - FusedMBConvConfig(3, 24, 48, 4, 2, 5), - FusedMBConvConfig(3, 48, 80, 4, 2, 5), - MBConvConfig(3, 80, 160, 4, 2, 7), - MBConvConfig(3, 160, 176, 6, 1, 14), - MBConvConfig(3, 176, 304, 6, 2, 18), - MBConvConfig(3, 304, 512, 6, 1, 5)], + (3, 24, 24, 1, 1, 3), + (3, 24, 48, 4, 2, 5), + (3, 48, 80, 4, 2, 5), + (3, 80, 160, 4, 2, 7), + (3, 160, 176, 6, 1, 14), + (3, 176, 304, 6, 2, 18), + (3, 304, 512, 6, 1, 5)], :large => [ - FusedMBConvConfig(3, 32, 32, 1, 1, 4), - FusedMBConvConfig(3, 32, 64, 4, 2, 7), - FusedMBConvConfig(3, 64, 96, 4, 2, 7), - MBConvConfig(3, 96, 192, 4, 2, 10), - MBConvConfig(3, 192, 224, 6, 1, 19), - MBConvConfig(3, 224, 384, 6, 2, 25), - MBConvConfig(3, 384, 640, 6, 1, 7)], + (3, 32, 32, 1, 1, 4), + (3, 32, 64, 4, 2, 7), + (3, 64, 96, 4, 2, 7), + (3, 96, 192, 4, 2, 10), + (3, 192, 224, 6, 1, 19), + (3, 224, 384, 6, 2, 25), + (3, 384, 640, 6, 1, 7)], :xlarge => [ - FusedMBConvConfig(3, 32, 32, 1, 1, 4), - FusedMBConvConfig(3, 32, 64, 4, 2, 8), - FusedMBConvConfig(3, 64, 96, 4, 2, 8), - MBConvConfig(3, 96, 192, 4, 2, 16), - MBConvConfig(3, 192, 224, 6, 1, 24), - MBConvConfig(3, 384, 512, 6, 2, 32), - MBConvConfig(3, 512, 768, 6, 1, 8)]) + (3, 32, 32, 1, 1, 4), + (3, 32, 64, 4, 2, 8), + (3, 64, 96, 4, 2, 8), + (3, 96, 192, 4, 2, 16), + (3, 192, 224, 6, 1, 24), + (3, 384, 512, 6, 2, 32), + (3, 512, 768, 6, 1, 8)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, @@ -54,8 +55,9 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnet(EFFNETV2_CONFIGS[config]; headplanes = 1280, inchannels, - nclasses) + layers = efficientnet(EFFNETV2_CONFIGS[config], + vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4)); + headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 458481d73..39e283fd0 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -203,6 +203,8 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # Also get `planes_vec` needed for block `inplanes` and `planes` calculations pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) planes_vec = collect(planes_fn(block_repeats)) @@ -265,22 +267,26 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; return get_layers end +# TODO @theabhirath figure out a better name and potentially refactor other CNNs to use this function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) # Construct each stage stages = [] - for (stage_idx, num_blocks) in enumerate(block_repeats) + for (stage_idx, nblocks) in enumerate(block_repeats) # Construct the blocks for each stage - blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...) - for block_idx in 1:num_blocks] + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + return (length(branches) == 1) ? only(branches) : + Parallel(connection, branches...) + end push!(stages, Chain(blocks...)) end return Chain(stages...) end -function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, +function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection) + stage_blocks = resnet_stages(builders, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -302,39 +308,37 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, if block_type == basicblock @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" - get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, - activation, norm_layer, revnorm, attn_fn, - drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, - planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + builder = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, revnorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck - get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, - reduction_factor, activation, norm_layer, revnorm, - attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, - planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + builder = bottleneck_builder(block_repeats; inplanes, cardinality, + base_width, reduction_factor, activation, norm_layer, + revnorm, attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" - @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" - @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" + @assert isnothing(drop_block_rate) + "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" + @assert isnothing(drop_path_rate) + "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" + @assert reduction_factor == 1 + "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + downsample_tuple = downsample_opt, kwargs...) else # TODO: write better message when we have link to dev docs for resnet throw(ArgumentError("Unknown block type $block_type")) end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, get_layers, block_repeats, - connection$activation, classifier_fn) + return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)), + block_repeats, connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index e308e1125..08f2c87ee 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -54,7 +54,6 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; attn_fn = planes -> identity, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) - planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) # This is needed for block `inplanes` and `planes` calculations diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 45615df5e..9bdf1f913 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -19,7 +19,7 @@ include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, basic_conv_bn, dwsep_conv_bn, mbconv, fused_mbconv +export conv_norm, basic_conv_bn, dwsep_conv_bn include("drop.jl") export DropBlock, DropPath @@ -27,6 +27,9 @@ export DropBlock, DropPath include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens +include("mbconv.jl") +export mbconv, fused_mbconv + include("mlp.jl") export mlp_block, gated_mlp_block diff --git a/src/layers/conv.jl b/src/layers/conv.jl index bb39a0e07..d81d76c9c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -125,71 +125,3 @@ function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, conv_norm((1, 1), inplanes, outplanes, activation; eps, revnorm, use_norm = use_norm[2], bias = bias[2])) end - -# TODO add support for stochastic depth to mbconv and fused_mbconv -""" - mbconv(kernel_size, inplanes::Integer, explanes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing) - -Create a basic inverted residual block for MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `kernel_size`: kernel size of the convolutional layers - - `inplanes`: number of input feature maps - - `explanes`: The number of feature maps in the hidden layer - - `outplanes`: The number of output feature maps - - `activation`: The activation function for the first two convolution layer - - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)) -""" -function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, - norm_layer = BatchNorm) - @assert stride in [1, 2] "`stride` has to be 1 or 2" - layers = [] - # expand - if inplanes != explanes - append!(layers, - conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) - end - # depthwise - stride = dilation > 1 ? 1 : stride - append!(layers, - conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, dilation, pad = SamePad(), groups = explanes)) - # squeeze-excite layer - if !isnothing(reduction) - push!(layers, - squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, - gate_activation = hardσ)) - end - # project - append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) - return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : - Chain(layers...) -end - -function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, - explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, norm_layer = BatchNorm) - @assert stride in [1, 2] "`stride` has to be 1 or 2" - layers = [] - if explanes != inplanes - # fused expand - append!(layers, - conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, - pad = SamePad())) - # project - append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) - else - append!(layers, - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) - end - return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : - Chain(layers...) -end diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl new file mode 100644 index 000000000..af17505ea --- /dev/null +++ b/src/layers/mbconv.jl @@ -0,0 +1,65 @@ +# TODO add support for stochastic depth to mbconv and fused_mbconv +""" + mbconv(kernel_size, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Integer} = nothing) + +Create a basic inverted residual block for MobileNet variants +([reference](https://arxiv.org/abs/1905.02244)). + +# Arguments + + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps + - `explanes`: The number of feature maps in the hidden layer + - `outplanes`: The number of output feature maps + - `activation`: The activation function for the first two convolution layer + - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 + - `reduction`: The reduction factor for the number of hidden feature maps + in a squeeze and excite layer (see [`squeeze_excite`](#)) +""" +function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + layers = [] + # expand + if inplanes != explanes + append!(layers, + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) + end + # depthwise + stride = dilation > 1 ? 1 : stride + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, dilation, pad = SamePad(), groups = explanes)) + # squeeze-excite layer + if !isnothing(reduction) + push!(layers, + squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, + gate_activation = hardσ)) + end + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) + return Chain(layers...) +end + +function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `fused_mbconv`" + layers = [] + if explanes != inplanes + # fused expand + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) + else + append!(layers, + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) + end + return Chain(layers...) +end From fc03d70c8167c5ba7a8757bc6d125dcf7cecf2ff Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Mon, 15 Aug 2022 22:51:42 +0530 Subject: [PATCH 12/34] Fix minor hiccups --- src/convnets/efficientnets/core.jl | 15 ++++----- src/convnets/efficientnets/efficientnetv2.jl | 3 +- src/convnets/resnets/core.jl | 32 +++++++++++--------- src/convnets/resnets/res2net.jl | 1 + 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 139252a3b..22ddf8172 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,7 +1,7 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - depth_mult, width_mult = scalings + width_mult, depth_mult = scalings k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) @@ -17,7 +17,8 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, end function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; norm_layer = BatchNorm) + stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes @@ -40,22 +41,22 @@ end function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), - headplanes::Integer = _round_channels(block_configs[end][3] * - scalings[2], 8) * 4, + headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, block_configs[1][2], swish; norm_layer, - stride = 2, pad = SamePad())) + conv_norm((3, 3), inchannels, + _round_channels(block_configs[1][2] * scalings[1], 8), swish; + norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; scalings, norm_layer) append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[2], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index ecbeed07a..d9a9d0d77 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -56,7 +56,8 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) layers = efficientnet(EFFNETV2_CONFIGS[config], - vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4)); + vcat(fill(fused_mbconv_builder, 3), + fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 39e283fd0..95291bc6f 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -275,7 +275,7 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con # Construct the blocks for each stage blocks = map(1:nblocks) do block_idx branches = get_layers(stage_idx, block_idx) - return (length(branches) == 1) ? only(branches) : + return length(branches) == 1 ? only(branches) : Parallel(connection, branches...) end push!(stages, Chain(blocks...)) @@ -283,10 +283,10 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con return Chain(stages...) end -function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer}, +function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(builders, block_repeats, connection) + stage_blocks = resnet_stages(get_layers, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -308,17 +308,19 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, if block_type == basicblock @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" - builder = basicblock_builder(block_repeats; inplanes, reduction_factor, - activation, norm_layer, revnorm, attn_fn, - drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, kwargs...) + get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, revnorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck - builder = bottleneck_builder(block_repeats; inplanes, cardinality, - base_width, reduction_factor, activation, norm_layer, - revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, kwargs...) + get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" @@ -337,8 +339,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)), - block_repeats, connection$activation, classifier_fn) + return resnet((imsize..., inchannels), stem, get_layers, block_repeats, + connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 08f2c87ee..e308e1125 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -54,6 +54,7 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; attn_fn = planes -> identity, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) + planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) # This is needed for block `inplanes` and `planes` calculations From f2461a5e426f153e0d006a69a8a04e6a4cc17aee Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 16 Aug 2022 12:01:37 +0530 Subject: [PATCH 13/34] Moving closer to the one true function --- src/convnets/efficientnets/core.jl | 52 +++++++++-------- src/convnets/efficientnets/efficientnet.jl | 18 +++--- src/convnets/efficientnets/efficientnetv2.jl | 61 ++++++++++---------- src/convnets/mobilenets/mobilenetv2.jl | 43 +++++++------- src/layers/mbconv.jl | 3 +- 5 files changed, 90 insertions(+), 87 deletions(-) diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 22ddf8172..ee94ff270 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,62 +1,66 @@ -function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) +function mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, + round_fn = planes -> _round_channels(planes, 8)) width_mult, depth_mult = scalings - k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] - inplanes = _round_channels(inplanes * width_mult, 8) + k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + inplanes = round_fn(inplanes * width_mult) outplanes = _round_channels(outplanes * width_mult, 8) function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = mbconv((k, k), inplanes, explanes, outplanes, swish; norm_layer, - stride, reduction = 4) + block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, ceil(Int, nrepeats * depth_mult) end -function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) - k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] +function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) + k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = fused_mbconv((k, k), inplanes, explanes, outplanes, swish; + block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, stride) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, nrepeats end -function efficientnet_builder(block_configs::AbstractVector{NTuple{6, Int}}, - residual_fns::AbstractVector; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - bxs = [residual_fn(block_configs, stage_idx; scalings, norm_layer) +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, + residual_fns::AbstractVector; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) + bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) for (stage_idx, residual_fn) in enumerate(residual_fns)] return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end -function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, - residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), +function efficientnet(block_configs::AbstractVector{<:Tuple}, + residual_fns::AbstractVector; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, - _round_channels(block_configs[1][2] * scalings[1], 8), swish; - norm_layer, stride = 2, pad = SamePad())) + conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8), + swish; norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks - get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; - scalings, norm_layer) + get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns; + inplanes, scalings, norm_layer) append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), + conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index bff9d8dde..495b8658f 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -1,13 +1,13 @@ # block configs for EfficientNet const EFFICIENTNET_BLOCK_CONFIGS = [ - # k, i, o, e, s, n - (3, 32, 16, 1, 1, 1), - (3, 16, 24, 6, 2, 2), - (5, 24, 40, 6, 2, 2), - (3, 40, 80, 6, 2, 3), - (5, 80, 112, 6, 1, 3), - (5, 112, 192, 6, 2, 4), - (3, 192, 320, 6, 1, 1), + # k, c, e, s, n, r, a + (3, 16, 1, 1, 1, 4, swish), + (3, 24, 6, 2, 2, 4, swish), + (5, 40, 6, 2, 2, 4, swish), + (3, 80, 6, 2, 3, 4, swish), + (5, 112, 6, 1, 3, 4, swish), + (5, 192, 6, 2, 4, 4, swish), + (3, 320, 6, 1, 1, 4, swish), ] # Data is organised as (r, (w, d)) # r: image resolution @@ -46,7 +46,7 @@ function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Intege scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS, fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS)); - scalings, inchannels, nclasses) + inplanes = 32, scalings, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index d9a9d0d77..69f67cc81 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,36 +1,36 @@ # block configs for EfficientNetv2 -# data organised as (k, i, o, e, s, n) +# data organised as (k, c, e, s, n, r, a) const EFFNETV2_CONFIGS = Dict(:small => [ - (3, 24, 24, 1, 1, 2), - (3, 24, 48, 4, 2, 4), - (3, 48, 64, 4, 2, 4), - (3, 64, 128, 4, 2, 6), - (3, 128, 160, 6, 1, 9), - (3, 160, 256, 6, 2, 15)], + (3, 24, 1, 1, 2, nothing, swish), + (3, 48, 4, 2, 4, nothing, swish), + (3, 64, 4, 2, 4, nothing, swish), + (3, 128, 4, 2, 6, 4, swish), + (3, 160, 6, 1, 9, 4, swish), + (3, 256, 6, 2, 15, 4, swish)], :medium => [ - (3, 24, 24, 1, 1, 3), - (3, 24, 48, 4, 2, 5), - (3, 48, 80, 4, 2, 5), - (3, 80, 160, 4, 2, 7), - (3, 160, 176, 6, 1, 14), - (3, 176, 304, 6, 2, 18), - (3, 304, 512, 6, 1, 5)], + (3, 24, 1, 1, 3, nothing, swish), + (3, 48, 4, 2, 5, nothing, swish), + (3, 80, 4, 2, 5, nothing, swish), + (3, 160, 4, 2, 7, 4, swish), + (3, 176, 6, 1, 14, 4, swish), + (3, 304, 6, 2, 18, 4, swish), + (3, 512, 6, 1, 5, 4, swish)], :large => [ - (3, 32, 32, 1, 1, 4), - (3, 32, 64, 4, 2, 7), - (3, 64, 96, 4, 2, 7), - (3, 96, 192, 4, 2, 10), - (3, 192, 224, 6, 1, 19), - (3, 224, 384, 6, 2, 25), - (3, 384, 640, 6, 1, 7)], + (3, 32, 1, 1, 4, nothing, swish), + (3, 64, 4, 2, 7, nothing, swish), + (3, 96, 4, 2, 7, nothing, swish), + (3, 192, 4, 2, 10, 4, swish), + (3, 224, 6, 1, 19, 4, swish), + (3, 384, 6, 2, 25, 4, swish), + (3, 640, 6, 1, 7, 4, swish)], :xlarge => [ - (3, 32, 32, 1, 1, 4), - (3, 32, 64, 4, 2, 8), - (3, 64, 96, 4, 2, 8), - (3, 96, 192, 4, 2, 16), - (3, 192, 224, 6, 1, 24), - (3, 384, 512, 6, 2, 32), - (3, 512, 768, 6, 1, 8)]) + (3, 32, 1, 1, 4, nothing, swish), + (3, 64, 4, 2, 8, nothing, swish), + (3, 96, 4, 2, 8, nothing, swish), + (3, 192, 4, 2, 16, 4, swish), + (3, 384, 6, 1, 24, 4, swish), + (3, 512, 6, 2, 32, 4, swish), + (3, 768, 6, 1, 8, 4, swish)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, @@ -58,9 +58,10 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false, layers = efficientnet(EFFNETV2_CONFIGS[config], vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); - headplanes = 1280, inchannels, nclasses) + inplanes = EFFNETV2_CONFIGS[config][1][2], headplanes = 1280, + inchannels, nclasses) if pretrain - loadpretrain!(layers, string("efficientnetv2")) + loadpretrain!(layers, string("efficientnetv2-", config)) end return EfficientNetv2(layers) end diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index d81256968..16b59c3a8 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -14,7 +14,6 @@ Create a MobileNetv2 model. + `c`: The number of output feature maps + `n`: The number of times a block is repeated + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper) @@ -24,41 +23,39 @@ Create a MobileNetv2 model. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ -function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, +function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1280, divisor::Integer = 8, + inplanes::Integer = 32, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer - inplanes = _round_channels(32 * width_mult, divisor) + inplanes = _round_channels(inplanes * width_mult, divisor) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - for (t, c, n, s, activation) in configs - outplanes = _round_channels(c * width_mult, divisor) - for i in 1:n - stride = i == 1 ? s : 1 - push!(layers, - mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, - activation; stride)) - inplanes = outplanes - end - end + get_layers, block_repeats = mbconv_stack_builder(block_configs, + fill(mbconv_builder, + length(block_configs)); + inplanes) + append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6)) + append!(layers, + conv_norm((1, 1), _round_channels(block_configs[end][2], 8), + outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 const MOBILENETV2_CONFIGS = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), + # k, c, e, s, n, r, a + (3, 16, 1, 1, 1, nothing, relu6), + (3, 24, 6, 2, 2, nothing, relu6), + (3, 32, 6, 2, 3, nothing, relu6), + (3, 64, 6, 2, 4, nothing, relu6), + (3, 96, 6, 1, 3, nothing, relu6), + (3, 160, 6, 2, 3, nothing, relu6), + (3, 320, 6, 1, 1, nothing, relu6), ] """ diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index af17505ea..eeac44dbd 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -59,7 +59,8 @@ function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) else append!(layers, - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) + conv_norm(kernel_size, inplanes, outplanes, activation; pad = SamePad(), + norm_layer, stride)) end return Chain(layers...) end From 9f6b987e555592d7680ffe09b2723f74e9f6693b Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 16 Aug 2022 15:06:29 +0530 Subject: [PATCH 14/34] Some more reorganisation --- src/Metalhead.jl | 5 ++ src/convnets/builders/core.jl | 15 +++++ src/convnets/builders/mbconv.jl | 44 ++++++++++++ src/convnets/builders/resblocks.jl | 71 ++++++++++++++++++++ src/convnets/efficientnets/core.jl | 47 +------------ src/convnets/mobilenets/mobilenetv2.jl | 2 +- src/convnets/resnets/core.jl | 93 +------------------------- 7 files changed, 140 insertions(+), 137 deletions(-) create mode 100644 src/convnets/builders/core.jl create mode 100644 src/convnets/builders/mbconv.jl create mode 100644 src/convnets/builders/resblocks.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 6b0179f45..8c8800e84 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -19,6 +19,11 @@ include("layers/Layers.jl") using .Layers # CNN models +## Builders +include("convnets/builders/core.jl") +include("convnets/builders/mbconv.jl") +include("convnets/builders/resblocks.jl") +## AlexNet and VGG include("convnets/alexnet.jl") include("convnets/vgg.jl") ## ResNets diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl new file mode 100644 index 000000000..413c78c27 --- /dev/null +++ b/src/convnets/builders/core.jl @@ -0,0 +1,15 @@ +# TODO potentially refactor other CNNs to use this +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) + # Construct each stage + stages = [] + for (stage_idx, nblocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + return length(branches) == 1 ? only(branches) : + Parallel(connection, branches...) + end + push!(stages, Chain(blocks...)) + end + return Chain(stages...) +end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl new file mode 100644 index 000000000..3312c1cfe --- /dev/null +++ b/src/convnets/builders/mbconv.jl @@ -0,0 +1,44 @@ +function mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, + round_fn = planes -> _round_channels(planes, 8)) + width_mult, depth_mult = scalings + k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + inplanes = round_fn(inplanes * width_mult) + outplanes = _round_channels(outplanes * width_mult, 8) + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, ceil(Int, nrepeats * depth_mult) +end + +function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) + k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; + norm_layer, stride) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats +end + +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, + residual_fns::AbstractVector; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) + bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) + for (stage_idx, residual_fn) in enumerate(residual_fns)] + return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) +end diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl new file mode 100644 index 000000000..2bf03746f --- /dev/null +++ b/src/convnets/builders/resblocks.jl @@ -0,0 +1,71 @@ +function basicblock_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 1, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + drop_block_rate = nothing, drop_path_rate = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # Also get `planes_vec` needed for block `inplanes` and `planes` calculations + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = stride != 1 || inplanes != planes * expansion ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = basicblock(inplanes, planes; stride, reduction_factor, activation, + norm_layer, revnorm, attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end + +function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + drop_block_rate = nothing, drop_path_rate = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = stride != 1 || inplanes != planes * expansion ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = bottleneck(inplanes, planes; stride, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end \ No newline at end of file diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index ee94ff270..c37fb8c93 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,48 +1,3 @@ -function mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, - round_fn = planes -> _round_channels(planes, 8)) - width_mult, depth_mult = scalings - k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] - inplanes = round_fn(inplanes * width_mult) - outplanes = _round_channels(outplanes * width_mult, 8) - function get_layers(block_idx) - inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) - stride = block_idx == 1 ? stride : 1 - block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, - stride, reduction) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) - end - return get_layers, ceil(Int, nrepeats * depth_mult) -end - -function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] - function get_layers(block_idx) - inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) - stride = block_idx == 1 ? stride : 1 - block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; - norm_layer, stride) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) - end - return get_layers, nrepeats -end - -function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, - residual_fns::AbstractVector; inplanes::Integer, - scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) - bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) - for (stage_idx, residual_fn) in enumerate(residual_fns)] - return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) -end - function efficientnet(block_configs::AbstractVector{<:Tuple}, residual_fns::AbstractVector; inplanes::Integer, scalings::NTuple{2, Real} = (1, 1), @@ -57,7 +12,7 @@ function efficientnet(block_configs::AbstractVector{<:Tuple}, # building inverted residual blocks get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns; inplanes, scalings, norm_layer) - append!(layers, resnet_stages(get_layers, block_repeats, +)) + append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers append!(layers, conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8), diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 16b59c3a8..ad41ff967 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -37,7 +37,7 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = fill(mbconv_builder, length(block_configs)); inplanes) - append!(layers, resnet_stages(get_layers, block_repeats, +)) + append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 95291bc6f..83ab9da04 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -195,98 +195,10 @@ function resnet_planes(block_repeats::AbstractVector{<:Integer}) for (stage_idx, stages) in enumerate(block_repeats)) end -function basicblock_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, reduction_factor::Integer = 1, - expansion::Integer = 1, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = nothing, drop_path_rate = nothing, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # Also get `planes_vec` needed for block `inplanes` and `planes` calculations - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # This is also needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = stride != 1 || inplanes != planes * expansion ? - downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - block = basicblock(inplanes, planes; stride, reduction_factor, activation, - norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - -function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, cardinality::Integer = 1, - base_width::Integer = 64, reduction_factor::Integer = 1, - expansion::Integer = 4, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = nothing, drop_path_rate = nothing, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # This is also needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = stride != 1 || inplanes != planes * expansion ? - downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - block = bottleneck(inplanes, planes; stride, cardinality, base_width, - reduction_factor, activation, norm_layer, revnorm, - attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - -# TODO @theabhirath figure out a better name and potentially refactor other CNNs to use this -function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) - # Construct each stage - stages = [] - for (stage_idx, nblocks) in enumerate(block_repeats) - # Construct the blocks for each stage - blocks = map(1:nblocks) do block_idx - branches = get_layers(stage_idx, block_idx) - return length(branches) == 1 ? only(branches) : - Parallel(connection, branches...) - end - push!(stages, Chain(blocks...)) - end - return Chain(stages...) -end - function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection) + stage_blocks = cnn_stages(get_layers, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -352,7 +264,8 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) -# larger ResNet-like models +# block configurations for larger ResNet-like models that do not use +# depths 18 and 34 const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) From 785a95aea78d72e743602ea5e6105387c03ca4b9 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 18 Aug 2022 14:28:40 +0530 Subject: [PATCH 15/34] Huge refactor of MobileNet and EfficientNet families The rise of the builders --- src/convnets/builders/core.jl | 18 +++- src/convnets/builders/mbconv.jl | 108 +++++++++++++++---- src/convnets/builders/resblocks.jl | 2 +- src/convnets/efficientnets/core.jl | 9 +- src/convnets/efficientnets/efficientnet.jl | 29 +++-- src/convnets/efficientnets/efficientnetv2.jl | 75 ++++++------- src/convnets/mobilenets/mobilenetv1.jl | 41 ++++--- src/convnets/mobilenets/mobilenetv2.jl | 23 ++-- src/convnets/mobilenets/mobilenetv3.jl | 91 +++++++--------- src/convnets/resnets/core.jl | 29 ++--- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 48 --------- src/layers/drop.jl | 6 +- src/layers/mbconv.jl | 108 ++++++++++++++++--- src/layers/selayers.jl | 10 +- 15 files changed, 347 insertions(+), 252 deletions(-) diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl index 413c78c27..e02092eca 100644 --- a/src/convnets/builders/core.jl +++ b/src/convnets/builders/core.jl @@ -11,5 +11,21 @@ function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connec end push!(stages, Chain(blocks...)) end - return Chain(stages...) + return stages +end + +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}) + # Construct each stage + stages = [] + for (stage_idx, nblocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + @assert length(branches)==1 "get_layers should return a single branch for each + block if no connection is specified" + return only(branches) + end + push!(stages, Chain(blocks...)) + end + return stages end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index 3312c1cfe..c90b8215c 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -1,44 +1,108 @@ -function mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, - round_fn = planes -> _round_channels(planes, 8)) +function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + width_mult::Number; norm_layer = BatchNorm, kwargs...) + _, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] + outplanes = floor(Int, outplanes * width_mult) + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + stride = block_idx == 1 ? stride : 1 + block = Chain(dwsep_conv_bn((k, k), inplanes, outplanes, activation; + stride, pad = SamePad(), norm_layer, kwargs...)...) + return (block,) + end + return get_layers, nrepeats +end + +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + scalings::NTuple{2, Real}; norm_layer = BatchNorm, + round_fn = planes -> _round_channels(planes, 8), kwargs...) width_mult, depth_mult = scalings - k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] inplanes = round_fn(inplanes * width_mult) outplanes = _round_channels(outplanes * width_mult, 8) - function get_layers(block_idx) + function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, - stride, reduction) + block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction, no_skip = true, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, ceil(Int, nrepeats * depth_mult) end -function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] - function get_layers(block_idx) +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + width_mult::Real; norm_layer = BatchNorm, kwargs...) + block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction, no_skip = true, kwargs...) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats +end + +function fused_mbconv_builder(block_configs, inplanes::Integer, + stage_idx::Integer; norm_layer = BatchNorm, kwargs...) + _, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; - norm_layer, stride) + norm_layer, stride, no_skip = true, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, nrepeats end -function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, - residual_fns::AbstractVector; inplanes::Integer, - scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) - bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) - for (stage_idx, residual_fn) in enumerate(residual_fns)] +# TODO - these builders need to be more flexible to potentially specify stuff like +# activation functions and reductions that don't change +function _get_builder(::typeof(dwsep_conv_bn), block_configs, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) + @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument" + return idx -> dwsepconv_builder(block_configs, inplanes, idx, width_mult; norm_layer, + kwargs...) +end + +function _get_builder(::Union{typeof(mbconv), typeof(mbconv_m3)}, block_configs, + inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) + if isnothing(scalings) + return idx -> mbconv_builder(block_configs, inplanes, idx, width_mult; norm_layer, + kwargs...) + elseif isnothing(width_mult) + return idx -> mbconv_builder(block_configs, inplanes, idx, scalings; norm_layer, + kwargs...) + else + throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified")) + end +end + +function _get_builder(::typeof(fused_mbconv), block_configs, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer) + @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument." + @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument" + return idx -> fused_mbconv_builder(block_configs, inplanes, idx; norm_layer) +end + +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, + norm_layer = BatchNorm, kwargs...) + bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes; scalings, + width_mult, norm_layer, kwargs...)(idx) + for idx in eachindex(block_configs)] return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl index 2bf03746f..8343bf811 100644 --- a/src/convnets/builders/resblocks.jl +++ b/src/convnets/builders/resblocks.jl @@ -68,4 +68,4 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; return block, downsample end return get_layers -end \ No newline at end of file +end diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index c37fb8c93..947080481 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,5 +1,4 @@ -function efficientnet(block_configs::AbstractVector{<:Tuple}, - residual_fns::AbstractVector; inplanes::Integer, +function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer, scalings::NTuple{2, Real} = (1, 1), headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, @@ -10,12 +9,12 @@ function efficientnet(block_configs::AbstractVector{<:Tuple}, conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8), swish; norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns; - inplanes, scalings, norm_layer) + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; scalings, + norm_layer) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 495b8658f..aaf958025 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -1,14 +1,22 @@ # block configs for EfficientNet +# data organised as (k, c, e, s, n, r, a) for each stage +# k: kernel size +# c: output channels +# e: expansion ratio +# s: stride +# n: number of repeats +# r: reduction ratio for squeeze-excite layer +# a: activation function const EFFICIENTNET_BLOCK_CONFIGS = [ - # k, c, e, s, n, r, a - (3, 16, 1, 1, 1, 4, swish), - (3, 24, 6, 2, 2, 4, swish), - (5, 40, 6, 2, 2, 4, swish), - (3, 80, 6, 2, 3, 4, swish), - (5, 112, 6, 1, 3, 4, swish), - (5, 192, 6, 2, 4, 4, swish), - (3, 320, 6, 1, 1, 4, swish), + (mbconv, 3, 16, 1, 1, 1, 4, swish), + (mbconv, 3, 24, 6, 2, 2, 4, swish), + (mbconv, 5, 40, 6, 2, 2, 4, swish), + (mbconv, 3, 80, 6, 2, 3, 4, swish), + (mbconv, 5, 112, 6, 1, 3, 4, swish), + (mbconv, 5, 192, 6, 2, 4, 4, swish), + (mbconv, 3, 320, 6, 1, 1, 4, swish), ] + # Data is organised as (r, (w, d)) # r: image resolution # w: width scaling @@ -44,9 +52,8 @@ function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Intege nclasses::Integer = 1000) _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] - layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS, - fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS)); - inplanes = 32, scalings, inchannels, nclasses) + layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings, + inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index 69f67cc81..188875ebf 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,36 +1,39 @@ # block configs for EfficientNetv2 -# data organised as (k, c, e, s, n, r, a) -const EFFNETV2_CONFIGS = Dict(:small => [ - (3, 24, 1, 1, 2, nothing, swish), - (3, 48, 4, 2, 4, nothing, swish), - (3, 64, 4, 2, 4, nothing, swish), - (3, 128, 4, 2, 6, 4, swish), - (3, 160, 6, 1, 9, 4, swish), - (3, 256, 6, 2, 15, 4, swish)], - :medium => [ - (3, 24, 1, 1, 3, nothing, swish), - (3, 48, 4, 2, 5, nothing, swish), - (3, 80, 4, 2, 5, nothing, swish), - (3, 160, 4, 2, 7, 4, swish), - (3, 176, 6, 1, 14, 4, swish), - (3, 304, 6, 2, 18, 4, swish), - (3, 512, 6, 1, 5, 4, swish)], - :large => [ - (3, 32, 1, 1, 4, nothing, swish), - (3, 64, 4, 2, 7, nothing, swish), - (3, 96, 4, 2, 7, nothing, swish), - (3, 192, 4, 2, 10, 4, swish), - (3, 224, 6, 1, 19, 4, swish), - (3, 384, 6, 2, 25, 4, swish), - (3, 640, 6, 1, 7, 4, swish)], - :xlarge => [ - (3, 32, 1, 1, 4, nothing, swish), - (3, 64, 4, 2, 8, nothing, swish), - (3, 96, 4, 2, 8, nothing, swish), - (3, 192, 4, 2, 16, 4, swish), - (3, 384, 6, 1, 24, 4, swish), - (3, 512, 6, 2, 32, 4, swish), - (3, 768, 6, 1, 8, 4, swish)]) +# data organised as (k, c, e, s, n, r, a) for each stage +# k: kernel size +# c: output channels +# e: expansion ratio +# s: stride +# n: number of repeats +# r: reduction ratio for squeeze-excite layer - specified only for `mbconv` +# a: activation function +const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish), + (fused_mbconv, 3, 48, 4, 2, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 4, swish), + (mbconv, 3, 128, 4, 2, 6, 4, swish), + (mbconv, 3, 160, 6, 1, 9, 4, swish), + (mbconv, 3, 256, 6, 2, 15, 4, swish)], + :medium => [(fused_mbconv, 3, 24, 1, 1, 3, swish), + (fused_mbconv, 3, 48, 4, 2, 5, swish), + (fused_mbconv, 3, 80, 4, 2, 5, swish), + (mbconv, 3, 160, 4, 2, 7, 4, swish), + (mbconv, 3, 176, 6, 1, 14, 4, swish), + (mbconv, 3, 304, 6, 2, 18, 4, swish), + (mbconv, 3, 512, 6, 1, 5, 4, swish)], + :large => [(fused_mbconv, 3, 32, 1, 1, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 7, swish), + (fused_mbconv, 3, 96, 4, 2, 7, swish), + (mbconv, 3, 192, 4, 2, 10, 4, swish), + (mbconv, 3, 224, 6, 1, 19, 4, swish), + (mbconv, 3, 384, 6, 2, 25, 4, swish), + (mbconv, 3, 640, 6, 1, 7, 4, swish)], + :xlarge => [(fused_mbconv, 3, 32, 1, 1, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 8, swish), + (fused_mbconv, 3, 96, 4, 2, 8, swish), + (mbconv, 3, 192, 4, 2, 16, 4, swish), + (mbconv, 3, 384, 6, 1, 24, 4, swish), + (mbconv, 3, 512, 6, 2, 32, 4, swish), + (mbconv, 3, 768, 6, 1, 8, 4, swish)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, @@ -55,11 +58,9 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnet(EFFNETV2_CONFIGS[config], - vcat(fill(fused_mbconv_builder, 3), - fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); - inplanes = EFFNETV2_CONFIGS[config][1][2], headplanes = 1280, - inchannels, nclasses) + block_configs = EFFNETV2_CONFIGS[config] + layers = efficientnet(block_configs; inplanes = block_configs[1][3], + headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2-", config)) end diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 542edec81..024f7060b 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -25,33 +25,28 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] - for (dw, outchannels, stride, nrepeats) in config - outchannels = floor(Int, outchannels * width_mult) - for _ in 1:nrepeats - layer = dw ? - dwsep_conv_bn((3, 3), inchannels, outchannels, activation; - stride, pad = 1) : - conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1) - append!(layers, layer) - inchannels = outchannels - end - end - return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate)) + # stem of the model + append!(layers, + conv_norm((3, 3), inchannels, config[1][3], activation; stride = 2, pad = 1)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(config, config[1][3]; width_mult) + append!(layers, cnn_stages(get_layers, block_repeats)) + return Chain(Chain(layers...), + create_classifier(config[end][3], nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 const MOBILENETV1_CONFIGS = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), + # k, c, s, r + (dwsep_conv_bn, 3, 64, 1, 1, relu6), + (dwsep_conv_bn, 3, 128, 2, 1, relu6), + (dwsep_conv_bn, 3, 128, 1, 1, relu6), + (dwsep_conv_bn, 3, 256, 2, 1, relu6), + (dwsep_conv_bn, 3, 256, 1, 1, relu6), + (dwsep_conv_bn, 3, 512, 2, 1, relu6), + (dwsep_conv_bn, 3, 512, 1, 5, relu6), + (dwsep_conv_bn, 3, 1024, 2, 1, relu6), + (dwsep_conv_bn, 3, 1024, 1, 1, relu6), ] """ diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index ad41ff967..c233b3d5e 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -33,29 +33,26 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, - fill(mbconv_builder, - length(block_configs)); - inplanes) + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][2], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3], 8), outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 const MOBILENETV2_CONFIGS = [ - # k, c, e, s, n, r, a - (3, 16, 1, 1, 1, nothing, relu6), - (3, 24, 6, 2, 2, nothing, relu6), - (3, 32, 6, 2, 3, nothing, relu6), - (3, 64, 6, 2, 4, nothing, relu6), - (3, 96, 6, 1, 3, nothing, relu6), - (3, 160, 6, 2, 3, nothing, relu6), - (3, 320, 6, 1, 1, nothing, relu6), + # f, k, c, e, s, n r, a + (mbconv, 3, 16, 1, 1, 1, nothing, swish), + (mbconv, 3, 24, 6, 2, 2, nothing, swish), + (mbconv, 3, 32, 6, 2, 3, nothing, swish), + (mbconv, 3, 64, 6, 2, 4, nothing, swish), + (mbconv, 3, 96, 6, 1, 3, nothing, swish), + (mbconv, 3, 160, 6, 2, 3, nothing, swish), + (mbconv, 3, 320, 6, 1, 1, nothing, swish), ] """ diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 82265e125..ed8dda08b 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -25,77 +25,58 @@ Create a MobileNetv3 model. - `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, reduced_tail::Bool = false, - tail_dilated::Bool = false, dropout_rate = 0.2, + max_width::Integer = 1024, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) - explanes = 0 - nstages = length(configs) - reduced_divider = 1 # building inverted residual blocks - for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs) - dilation = 1 - if nstages - i <= 2 - if reduced_tail - reduced_divider = 2 - c /= reduced_divider - end - if tail_dilated - dilation = 2 - end - end - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - mbconv((k, k), inplanes, explanes, outplanes, activation; - stride, reduction, dilation)) - inplanes = outplanes - end + get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult) + append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers - headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8) - append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) + explanes = _round_channels(configs[end][3] * width_mult, 8) + midplanes = _round_channels(explanes * configs[end][4], 8) + headplanes = _round_channels(max_width * width_mult, 8) + append!(layers, conv_norm((1, 1), explanes, midplanes, hardswish)) return Chain(Chain(layers...), - create_classifier(explanes, headplanes, nclasses, + create_classifier(midplanes, headplanes, nclasses, (hardswish, identity); dropout_rate)) end # Layer configurations for small and large models for MobileNetv3 +# Data is organised as (f, k, c, e, s, n, r, a) +# f: mbconv block function - we use `mbconv_m3` for all blocks +# k: kernel size +# c: output channels +# e: expansion factor +# s: stride +# n: number of repeats +# r: squeeze and excite reduction factor +# a: activation function const MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, r, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), + # f, k, c, e, s, n, r, a + (mbconv_m3, 3, 16, 1, 2, 1, 4, relu), + (mbconv_m3, 3, 24, 4.5, 2, 1, nothing, relu), + (mbconv_m3, 3, 24, 3.67, 1, 1, nothing, relu), + (mbconv_m3, 5, 40, 4, 2, 1, 4, hardswish), + (mbconv_m3, 5, 40, 6, 1, 2, 4, hardswish), + (mbconv_m3, 5, 48, 3, 1, 2, 4, hardswish), + (mbconv_m3, 5, 96, 6, 1, 3, 4, hardswish), ], :large => [ - # k, t, c, r, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), + # f, k, c, e, s, n, r, a + (mbconv_m3, 3, 16, 1, 1, 1, nothing, relu), + (mbconv_m3, 3, 24, 4, 2, 1, nothing, relu), + (mbconv_m3, 3, 24, 3, 1, 1, nothing, relu), + (mbconv_m3, 5, 40, 3, 2, 1, 4, relu), + (mbconv_m3, 5, 40, 3, 1, 2, 4, relu), + (mbconv_m3, 3, 80, 6, 2, 1, nothing, hardswish), + (mbconv_m3, 3, 80, 2.5, 1, 1, nothing, hardswish), + (mbconv_m3, 3, 80, 2.3, 1, 2, nothing, hardswish), + (mbconv_m3, 3, 112, 6, 1, 2, 4, hardswish), + (mbconv_m3, 5, 160, 6, 1, 3, 4, hardswish), ]) """ diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 83ab9da04..9e54ec06d 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -71,7 +71,8 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm) + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, + revnorm) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, stride, pad = 1, groups = cardinality) conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) @@ -92,7 +93,8 @@ function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer norm_layer = BatchNorm, revnorm::Bool = false) pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm)...) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, + revnorm)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -161,8 +163,7 @@ on how to use this function. function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, replace_pool::Bool = false, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false) - @assert stem_type in [:default, :deep, :deep_tiered] - "Stem type must be one of [:default, :deep, :deep_tiered]" + _checkconfig(stem_type, [:default, :deep, :deep_tiered]) # Main stem deep_stem = stem_type == :deep || stem_type == :deep_tiered inplanes = deep_stem ? stem_width * 2 : 64 @@ -199,7 +200,7 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte connection, classifier_fn) # Build stages of the ResNet stage_blocks = cnn_stages(get_layers, block_repeats, connection) - backbone = Chain(stem, stage_blocks) + backbone = Chain(stem, stage_blocks...) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] return Chain(backbone, classifier_fn(nfeaturemaps)) @@ -207,12 +208,14 @@ end function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); - cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, + cardinality::Integer = 1, base_width::Integer = 64, + inplanes::Integer = 64, reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, + use_conv::Bool = false, drop_block_rate = nothing, + drop_path_rate = nothing, dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) @@ -234,12 +237,12 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, planes_fn = resnet_planes, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert isnothing(drop_block_rate) - "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" - @assert isnothing(drop_path_rate) - "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" - @assert reduction_factor == 1 - "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" + @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. + Set `drop_block_rate` to nothing." + @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. + Set `drop_path_rate` to nothing." + @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. + Set `reduction_factor` to 1." get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 9bdf1f913..e1b5197f0 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -28,7 +28,7 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mbconv.jl") -export mbconv, fused_mbconv +export mbconv, mbconv_m3, fused_mbconv include("mlp.jl") export mlp_block, gated_mlp_block diff --git a/src/layers/conv.jl b/src/layers/conv.jl index d81d76c9c..e49611280 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -64,7 +64,6 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end - function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...) inplanes, outplanes = ch @@ -78,50 +77,3 @@ function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = r return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, eps = 1.0f-3, kwargs...) end - -""" - dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, - stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), - pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) - -Create a depthwise separable convolution chain as used in MobileNetv1. -This is sequence of layers: - - - a `kernel_size` depthwise convolution from `inplanes => inplanes` - - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise - `activation` is applied to the convolution output) - - a `kernel_size` convolution from `inplanes => outplanes` - - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise - `activation` is applied to the convolution output) - -See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - -# Arguments - - - `kernel_size`: size of the convolution kernel (tuple) - - `inplanes`: number of input feature maps - - `outplanes`: number of output feature maps - - `activation`: the activation function for the final layer - - `revnorm`: set to `true` to place the batch norm before the convolution - - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and - second convolution - - `bias`: a tuple of two booleans to specify whether to use bias for the first and second - convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and - `use_norm[1] == true`. - - `stride`: stride of the first convolution kernel - - `pad`: padding of the first convolution kernel - - `dilation`: dilation of the first convolution kernel - - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) -""" -function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, - outplanes::Integer, activation = relu; eps::Float32 = 1.0f-5, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), - bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) - return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, - revnorm, use_norm = use_norm[1], stride, bias = bias[1], - groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; eps, - revnorm, use_norm = use_norm[2], bias = bias[2])) -end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 387b562ef..edff55234 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -93,10 +93,8 @@ end trainable(a::DropBlock) = (;) function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_scale) - @assert 0 ≤ drop_block_prob ≤ 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert 0 ≤ gamma_scale ≤ 1 - return "gamma_scale must be between 0 and 1, got $gamma_scale" + @assert 0≤drop_block_prob≤1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0≤gamma_scale≤1 "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index eeac44dbd..74938436d 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -1,3 +1,49 @@ +""" + dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, [bias, weight, init]) + +Create a depthwise separable convolution chain as used in MobileNetv1. +This is sequence of layers: + + - a `kernel_size` depthwise convolution from `inplanes => inplanes` + - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise + `activation` is applied to the convolution output) + - a `kernel_size` convolution from `inplanes => outplanes` + - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise + `activation` is applied to the convolution output) + +See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). + +# Arguments + + - `kernel_size`: size of the convolution kernel (tuple) + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + - `activation`: the activation function for the final layer + - `revnorm`: set to `true` to place the batch norm before the convolution + - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and + second convolution + - `bias`: a tuple of two booleans to specify whether to use bias for the first and second + convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and + `use_norm[1] == true`. + - `stride`: stride of the first convolution kernel + - `pad`: padding of the first convolution kernel + - `dilation`: dilation of the first convolution kernel + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) +""" +function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, + revnorm, use_norm = use_norm[1], stride, bias = bias[1], + groups = inplanes, kwargs...), + conv_norm((1, 1), inplanes, outplanes, activation; eps, + revnorm, use_norm = use_norm[2], bias = bias[2])) +end + # TODO add support for stochastic depth to mbconv and fused_mbconv """ mbconv(kernel_size, inplanes::Integer, explanes::Integer, @@ -21,8 +67,13 @@ Create a basic inverted residual block for MobileNet variants function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, - norm_layer = BatchNorm) + norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, + no_skip::Bool = false) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + if !isnothing(momentum) + @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) + end layers = [] # expand if inplanes != explanes @@ -30,7 +81,6 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end # depthwise - stride = dilation > 1 ? 1 : stride append!(layers, conv_norm(kernel_size, explanes, explanes, activation; norm_layer, stride, dilation, pad = SamePad(), groups = explanes)) @@ -42,25 +92,57 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, end # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) - return Chain(layers...) + use_skip = stride == 1 && inplanes == outplanes && !no_skip + return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) +end + +function mbconv_m3(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, + no_skip::Bool = false) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + if !isnothing(momentum) + @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) + end + layers = [] + # expand + if inplanes != explanes + append!(layers, + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) + end + # depthwise + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, dilation, pad = SamePad(), groups = explanes)) + # squeeze-excite layer + if !isnothing(reduction) + push!(layers, + squeeze_excite(explanes, _round_channels(explanes ÷ reduction, 8); + activation, + gate_activation = hardσ)) + end + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) + use_skip = stride == 1 && inplanes == outplanes && !no_skip + return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) end function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, norm_layer = BatchNorm) + stride::Integer, norm_layer = BatchNorm, no_skip::Bool = false) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `fused_mbconv`" layers = [] + # fused expand + explanes = explanes == inplanes ? outplanes : explanes + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) if explanes != inplanes - # fused expand - append!(layers, - conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, - pad = SamePad())) # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) - else - append!(layers, - conv_norm(kernel_size, inplanes, outplanes, activation; pad = SamePad(), - norm_layer, stride)) end - return Chain(layers...) + use_skip = stride == 1 && inplanes == outplanes && !no_skip + return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index db034d1af..c9f4873e9 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -22,6 +22,8 @@ Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE- - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer """ +# TODO look into a `get_norm_act` layer that will return a closure over the norm layer +# with the activation function passed in when the norm layer is not `identity` function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; norm_layer = planes -> identity, activation = relu, gate_activation = sigmoid) @@ -34,9 +36,8 @@ function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; gate_activation] return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end - -function squeeze_excite(inplanes::Integer; reduction::Integer = 16, rd_divisor::Integer = 8, - kwargs...) +function squeeze_excite(inplanes::Integer; reduction::Integer = 16, + rd_divisor::Integer = 8, kwargs...) return squeeze_excite(inplanes, _round_channels(inplanes ÷ reduction, rd_divisor, 0); kwargs...) end @@ -54,6 +55,5 @@ Effective squeeze-and-excitation layer. """ function effective_squeeze_excite(inplanes::Integer; gate_activation = sigmoid) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes, inplanes), - gate_activation), .*) + Conv((1, 1), inplanes => inplanes, gate_activation)), .*) end From 9e917833c1addf0bd7305d894f6c69b18fd60fc7 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 4 Aug 2022 18:16:47 +0530 Subject: [PATCH 16/34] Initial commit for EfficientNetv2 --- src/Metalhead.jl | 13 ++-- .../{ => efficientnet}/efficientnet.jl | 25 +++--- src/convnets/efficientnet/efficientnetv2.jl | 78 +++++++++++++++++++ src/convnets/mobilenet/mobilenetv2.jl | 3 +- src/convnets/mobilenet/mobilenetv3.jl | 11 ++- src/convnets/resnets/resnext.jl | 3 +- src/layers/conv.jl | 2 +- src/layers/mlp.jl | 7 +- 8 files changed, 114 insertions(+), 28 deletions(-) rename src/convnets/{ => efficientnet}/efficientnet.jl (82%) create mode 100644 src/convnets/efficientnet/efficientnetv2.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index aa236454c..cad13afc9 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -33,6 +33,9 @@ include("convnets/inception/inceptionv3.jl") include("convnets/inception/inceptionv4.jl") include("convnets/inception/inceptionresnetv2.jl") include("convnets/inception/xception.jl") +## EfficientNets +include("convnets/efficientnet/efficientnet.jl") +include("convnets/efficientnet/efficientnetv2.jl") ## MobileNets include("convnets/mobilenet/mobilenetv1.jl") include("convnets/mobilenet/mobilenetv2.jl") @@ -40,7 +43,6 @@ include("convnets/mobilenet/mobilenetv3.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") -include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") @@ -61,13 +63,14 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, - SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, + SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, EfficientNetv2, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, - :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, - :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, +for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, + :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, + :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, + :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :EfficientNetv2, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl similarity index 82% rename from src/convnets/efficientnet.jl rename to src/convnets/efficientnet/efficientnet.jl index 91986fb92..795fbd592 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -29,28 +29,27 @@ function efficientnet(scalings::NTuple{2, Real}, wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - out_channels = _round_channels(scalew(32), 8) - stem = conv_norm((3, 3), inchannels, out_channels, swish; bias = false, stride = 2, + outplanes = _round_channels(scalew(32), 8) + stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, pad = SamePad()) blocks = [] for (n, k, s, e, i, o) in block_configs - in_channels = _round_channels(scalew(i), 8) - out_channels = _round_channels(scalew(o), 8) + inchannels = _round_channels(scalew(i), 8) + outplanes = _round_channels(scalew(o), 8) repeats = scaled(n) push!(blocks, - invertedresidual((k, k), in_channels, out_channels, swish; expansion = e, + invertedresidual((k, k), in_channels, outplanes, swish; expansion = e, stride = s, reduction = 4)) for _ in 1:(repeats - 1) push!(blocks, - invertedresidual((k, k), out_channels, out_channels, swish; expansion = e, + invertedresidual((k, k), outplanes, outplanes, swish; expansion = e, stride = 1, reduction = 4)) end end - head_out_channels = _round_channels(max_width, 8) + headplanes = _round_channels(max_width, 8) append!(blocks, - conv_norm((1, 1), out_channels, head_out_channels, swish; - bias = false, pad = SamePad())) - return Chain(Chain(stem..., blocks...), create_classifier(head_out_channels, nclasses)) + conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) + return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) end # n: # of block repetitions @@ -101,9 +100,11 @@ struct EfficientNet end @functor EfficientNet -function EfficientNet(config::Symbol; pretrain::Bool = false) +function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS) + model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS; + inchannels, nclasses) if pretrain loadpretrain!(model, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl new file mode 100644 index 000000000..58a7aea46 --- /dev/null +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -0,0 +1,78 @@ +function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, + width_mult::Real = 1.0, inchannels::Integer = 3, + nclasses::Integer = 1000) + # building first layer + inplanes = _round_channels(24 * width_mult, 8) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, + bias = false)) + # building inverted residual blocks + for (t, c, n, s, r) in config + outplanes = _round_channels(c * width_mult, 8) + for i in 1:n + push!(layers, + invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, + stride = i == 1 ? s : 1, + reduction = r == 1 ? 4 : nothing)) + inplanes = outplanes + end + end + # building last layers + outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, 8) : + max_width + append!(layers, conv_norm((1, 1), inplanes, outplanes, swish; bias = false)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) +end + +const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE + (1, 24, 2, 1, 0), + (4, 48, 4, 2, 0), + (4, 64, 4, 2, 0), + (4, 128, 6, 2, 1), + (6, 160, 9, 1, 1), + (6, 256, 15, 2, 1)], + :medium => [# t, c, n, s, SE + (1, 24, 3, 1, 0), + (4, 48, 5, 2, 0), + (4, 80, 5, 2, 0), + (4, 160, 7, 2, 1), + (6, 176, 14, 1, 1), + (6, 304, 18, 2, 1), + (6, 512, 5, 1, 1)], + :large => [# t, c, n, s, SE + (1, 32, 4, 1, 0), + (4, 64, 8, 2, 0), + (4, 96, 8, 2, 0), + (4, 192, 16, 2, 1), + (6, 256, 24, 1, 1), + (6, 512, 32, 2, 1), + (6, 640, 8, 1, 1)], + :xlarge => [# t, c, n, s, SE + (1, 32, 4, 1, 0), + (4, 64, 8, 2, 0), + (4, 96, 8, 2, 0), + (4, 192, 16, 2, 1), + (6, 256, 24, 1, 1), + (6, 512, 32, 2, 1), + (6, 640, 8, 1, 1)]) + +struct EfficientNetv2 + layers::Any +end +@functor EfficientNetv2 + +function EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) + layers = efficientnetv2(EFFNETV2_CONFIGS[config]; width_mult, inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnetv2")) + end + return EfficientNetv2(layers) +end + +(m::EfficientNetv2)(x) = m.layers(x) + +backbone(m::EfficientNetv2) = m.layers[1] +classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 84162e985..c73644073 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -29,7 +29,8 @@ function mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; # building first layer inplanes = _round_channels(32 * width_mult, divisor) layers = [] - append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) + append!(layers, + conv_norm((3, 3), inchannels, inplanes; bias = false, pad = 1, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, divisor) diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 7d06ab14d..607069bdd 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -24,14 +24,14 @@ Create a MobileNetv3 model. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes """ -function mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; +function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1024, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, + conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1, bias = false)) explanes = 0 # building inverted residual blocks @@ -45,9 +45,8 @@ function mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; inplanes = outplanes end # building last layers - output_channel = max_width - output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : - output_channel + output_channel = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : + max_width append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, output_channel, hardswish), @@ -119,7 +118,7 @@ function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = fals inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, [:small, :large]) max_width = (config == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[config]; max_width, inchannels, + layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv3", config)) diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 8c43d2f62..bb589b97b 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -32,7 +32,8 @@ function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = _checkconfig(depth, keys(LRESNET_CONFIGS)) layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain - loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d")) + loadpretrain!(layers, + string("resnext", depth, "_", cardinality, "x", base_width, "d")) end return ResNeXt(layers) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index de214bcbc..c272d724a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -136,7 +136,7 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer norm_layer = BatchNorm) invres = Chain(conv1..., conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad = pad, groups = hidden_planes)..., + bias = false, stride, pad, groups = hidden_planes)..., selayer, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 0e496097b..500c31811 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -73,12 +73,15 @@ function create_classifier(inplanes::Integer, nclasses::Integer, activation = id "`pool_layer` must be identity if `use_conv` is true" end classifier = [] - flatten_in_pool ? push!(classifier, pool_layer, MLUtils.flatten) : + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else push!(classifier, pool_layer) + end # Dropout is applied after the pooling layer isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) # Fully-connected layer use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) : - push!(classifier, Dense(inplanes => nclasses, activation)) + push!(classifier, Dense(inplanes => nclasses, activation)) return Chain(classifier...) end From 4a94569190d50b2e980f5e5aeb51eebaf0b99afd Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 5 Aug 2022 11:39:44 +0530 Subject: [PATCH 17/34] Cleanup --- src/convnets/densenet.jl | 3 ++- src/convnets/efficientnet/efficientnet.jl | 2 +- src/convnets/inception/inceptionresnetv2.jl | 7 +++---- src/convnets/inception/inceptionv4.jl | 3 ++- src/convnets/mobilenet/mobilenetv2.jl | 12 ++++++------ src/convnets/mobilenet/mobilenetv3.jl | 17 ++++++++--------- src/convnets/resnets/core.jl | 12 ++++++------ src/convnets/resnets/resnext.jl | 8 +++++--- src/layers/conv.jl | 21 +++++++++++++-------- src/layers/mlp.jl | 7 +++++-- src/utilities.jl | 2 +- src/vit-based/vit.jl | 2 +- 12 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index eb29c4966..badb61a9e 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -132,7 +132,8 @@ end function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32, reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(DENSENET_CONFIGS)) - layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses) + layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, + nclasses) if pretrain loadpretrain!(layers, string("densenet", config)) end diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl index 795fbd592..518ff15f2 100644 --- a/src/convnets/efficientnet/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -38,7 +38,7 @@ function efficientnet(scalings::NTuple{2, Real}, outplanes = _round_channels(scalew(o), 8) repeats = scaled(n) push!(blocks, - invertedresidual((k, k), in_channels, outplanes, swish; expansion = e, + invertedresidual((k, k), inchannels, outplanes, swish; expansion = e, stride = s, reduction = 4)) for _ in 1:(repeats - 1) push!(blocks, diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index a8bfbaefa..54bc59952 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -75,7 +75,7 @@ Creates an InceptionResNetv2 model. - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, +function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., conv_norm((3, 3), 32, 32)..., @@ -96,8 +96,8 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, end """ - InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) + InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -106,7 +106,6 @@ Creates an InceptionResNetv2 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl index cd4971742..7f027da6e 100644 --- a/src/convnets/inception/inceptionv4.jl +++ b/src/convnets/inception/inceptionv4.jl @@ -121,7 +121,8 @@ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, end """ - Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) + Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index c73644073..531a7c9da 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -1,5 +1,5 @@ """ - mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; + mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1280, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -8,9 +8,6 @@ Create a MobileNetv2 model. # Arguments - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - `configs`: A "list of tuples" configuration for each layer that details: + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer @@ -18,11 +15,14 @@ Create a MobileNetv2 model. + `n`: The number of times a block is repeated + `s`: The stride of the convolutional kernel + `a`: The activation function used in the bottleneck layer + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes """ -function mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; +function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1280, inchannels::Integer = 3, nclasses::Integer = 1000) divisor = width_mult == 0.1 ? 4 : 8 @@ -86,7 +86,7 @@ end function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) + layers = mobilenetv2(MOBILENETV2_CONFIGS; width_mult, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv2")) end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 607069bdd..01acf9c54 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -1,5 +1,5 @@ """ - mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; + mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1024, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -8,10 +8,6 @@ Create a MobileNetv3 model. # Arguments - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `configs`: a "list of tuples" configuration for each layer that details: + `k::Integer` - The size of the convolutional kernel @@ -20,6 +16,9 @@ Create a MobileNetv3 model. + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + `s::Integer` - The stride of the convolutional kernel + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4.) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes @@ -45,13 +44,13 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = outplanes end # building last layers - output_channel = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : + headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : max_width append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(explanes, output_channel, hardswish), + Dense(explanes, headplanes, hardswish), Dropout(0.2), - Dense(output_channel, nclasses)) + Dense(headplanes, nclasses)) return Chain(Chain(layers...), classifier) end @@ -117,7 +116,7 @@ end function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, [:small, :large]) - max_width = (config == :large) ? 1280 : 1024 + max_width = config == :large ? 1280 : 1024 layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels, nclasses) if pretrain diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 699edcbe8..afc446d14 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -85,15 +85,15 @@ end # Downsample layer using convolutions. function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, revnorm = false) + norm_layer = BatchNorm, revnorm::Bool = false) return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, pad = SamePad(), stride, bias = false)...) end # Downsample layer using max pooling function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, revnorm = false) - pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) + norm_layer = BatchNorm, revnorm::Bool = false) + pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, bias = false)...) @@ -123,7 +123,7 @@ const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity), # Stride for each block in the ResNet model function resnet_stride(stage_idx::Integer, block_idx::Integer) - return (stage_idx == 1 || block_idx != 1) ? 1 : 2 + return stage_idx == 1 || block_idx != 1 ? 1 : 2 end # returns `DropBlock`s for each stage of the ResNet as in timm. @@ -221,7 +221,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_fn = stride != 1 || inplanes != planes * expansion ? downsample_tuple[1] : downsample_tuple[2] drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) @@ -256,7 +256,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_fn = stride != 1 || inplanes != planes * expansion ? downsample_tuple[1] : downsample_tuple[2] drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index bb589b97b..664202d15 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -8,11 +8,13 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. # Arguments - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. Supported configurations are: - - depth 50, cardinality of 32 and base width of 4. - - depth 101, cardinality of 32 and base width of 8. - - depth 101, cardinality of 64 and base width of 4. + + + depth 50, cardinality of 32 and base width of 4. + + depth 101, cardinality of 32 and base width of 8. + + depth 101, cardinality of 64 and base width of 4. - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. - `inchannels`: the number of input channels. diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c272d724a..cdbfc472c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -107,8 +107,13 @@ function depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Inte end """ - invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; - stride, reduction = nothing) + invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Integer} = nothing) + + invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, + activation = relu; stride::Integer, expansion::Real, + reduction::Union{Nothing, Integer} = nothing) Create a basic inverted residual block for MobileNet variants ([reference](https://arxiv.org/abs/1905.02244)). @@ -117,7 +122,8 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer + - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, + specify the keyword argument `expansion`, which calculates - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 @@ -129,7 +135,7 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer reduction::Union{Nothing, Integer} = nothing) @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 - conv1 = (inplanes == hidden_planes) ? (identity,) : + conv1 = inplanes == hidden_planes ? (identity,) : conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false) selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, @@ -139,13 +145,12 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer bias = false, stride, pad, groups = hidden_planes)..., selayer, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) - return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres + return stride == 1 && inplanes == outplanes ? SkipConnection(invres, +) : invres end function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) - hidden_planes = floor(Int, inplanes * expansion) - return invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation; - stride, reduction) + return invertedresidual(kernel_size, inplanes, floor(Int, inplanes * expansion), + outplanes, activation; stride, reduction) end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 500c31811..467df30a4 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -81,7 +81,10 @@ function create_classifier(inplanes::Integer, nclasses::Integer, activation = id # Dropout is applied after the pooling layer isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) # Fully-connected layer - use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) : - push!(classifier, Dense(inplanes => nclasses, activation)) + if use_conv + push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) + else + push!(classifier, Dense(inplanes => nclasses, activation)) + end return Chain(classifier...) end diff --git a/src/utilities.jl b/src/utilities.jl index f5737831c..4a611b5a2 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -6,7 +6,7 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) function _round_channels(channels, divisor, min_value = divisor) new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) # Make sure that round down does not go down by more than 10% - return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels + return new_channels < 0.9 * channels ? new_channels + divisor : new_channels end """ diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 099d00639..75bfb5b07 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -61,7 +61,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3, Dropout(emb_dropout_rate), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout_rate), - (pool == :class) ? x -> x[:, 1, :] : seconddimmean), + pool == :class ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end From 69563a6b7fc591c5f273a880b3d3987707ffe02b Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 08:38:59 +0530 Subject: [PATCH 18/34] Add docs for EfficientNetv2 Also misc. docs and formatting --- src/convnets/convmixer.jl | 4 +- src/convnets/convnext.jl | 2 +- src/convnets/efficientnet/efficientnetv2.jl | 50 ++++++++++++++++++++- src/convnets/inception/inceptionv3.jl | 3 +- src/convnets/inception/xception.jl | 2 +- src/convnets/mobilenet/mobilenetv1.jl | 4 +- src/convnets/mobilenet/mobilenetv2.jl | 4 +- src/convnets/mobilenet/mobilenetv3.jl | 8 ++-- src/convnets/resnets/core.jl | 18 ++++---- src/convnets/resnets/res2net.jl | 23 +++++----- src/convnets/resnets/resnet.jl | 2 +- src/convnets/resnets/resnext.jl | 3 +- src/convnets/resnets/seresnet.jl | 14 +++--- src/layers/conv.jl | 6 ++- src/layers/embeddings.jl | 2 +- 15 files changed, 99 insertions(+), 46 deletions(-) diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 309989d2d..c7dd058ff 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -13,7 +13,7 @@ Creates a ConvMixer model. - `kernel_size`: kernel size of the convolutional layers - `patch_size`: size of the patches - `activation`: activation function used after the convolutional layers - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), @@ -48,7 +48,7 @@ Creates a ConvMixer model. # Arguments - `config`: the size of the model, either `:base`, `:small` or `:large` - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ struct ConvMixer diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 040a409ab..15271cfed 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -84,7 +84,7 @@ Creates a ConvNeXt model. # Arguments - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`. - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of output classes See also [`Metalhead.convnext`](#). diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl index 58a7aea46..40265bd2c 100644 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -1,3 +1,28 @@ +""" + efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, + width_mult::Real = 1.0, inchannels::Integer = 3, + nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `t`: expansion factor of the block + + `c`: output channels of the block (will be scaled by width_mult) + + `n`: number of block repetitions + + `s`: kernel stride in the block except the first block of each stage + + `se`: whether to use a `squeeze_excite` layer in the block or not + + - `max_width`: maximum number of output channels before the fully connected + classification blocks + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, width_mult::Real = 1.0, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -8,13 +33,13 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, bias = false)) # building inverted residual blocks - for (t, c, n, s, r) in config + for (t, c, n, s, se) in config outplanes = _round_channels(c * width_mult, 8) for i in 1:n push!(layers, invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, stride = i == 1 ? s : 1, - reduction = r == 1 ? 4 : nothing)) + reduction = se == 1 ? 4 : nothing)) inplanes = outplanes end end @@ -25,6 +50,12 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) end +# config dict of inverted residual blocks for EfficientNetv2 +# t: expansion factor of the block +# c: output channels of the block (will be scaled by width_mult) +# n: number of block repetitions +# s: kernel stride in the block except the first block of each stage +# se: whether to use a `squeeze_excite` layer in the block or not const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE (1, 24, 2, 1, 0), (4, 48, 4, 2, 0), @@ -57,6 +88,21 @@ const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE (6, 512, 32, 2, 1), (6, 640, 8, 1, 1)]) +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `width_mult`: Controls the number of output feature maps in each block (with 1 + being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" struct EfficientNetv2 layers::Any end diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl index bc5ec3a2b..4f13d3695 100644 --- a/src/convnets/inception/inceptionv3.jl +++ b/src/convnets/inception/inceptionv3.jl @@ -133,7 +133,8 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - `nclasses`: the number of output classes """ -function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) +function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, + nclasses::Integer = 1000) backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., conv_norm((3, 3), 32, 32)..., conv_norm((3, 3), 32, 64; pad = 1)..., diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index 4964f3ca1..d4751352c 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -8,7 +8,7 @@ Create an Xception block. # Arguments - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `outchannels`: number of output channels. - `nrepeats`: number of repeats of depthwise separable convolution layers. - `stride`: stride by which to downsample the input. diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index b6d9fe8ee..b390a3f55 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -7,7 +7,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) + (with 1 being the default in the paper) - `configs`: A "list of tuples" configuration for each layer that details: @@ -63,7 +63,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 531a7c9da..fd5bc6691 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -17,7 +17,7 @@ Create a MobileNetv2 model. + `a`: The activation function used in the bottleneck layer - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) + (with 1 being the default in the paper) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes @@ -71,7 +71,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: Whether to load the pre-trained weights for ImageNet - `inchannels`: The number of input channels. diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 01acf9c54..82c5fb187 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -18,7 +18,7 @@ Create a MobileNetv3 model. + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4.) + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes @@ -45,7 +45,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end # building last layers headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : - max_width + max_width append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), @@ -100,10 +100,10 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `config`: :small or :large for the size of the model (see paper). - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: whether to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: the number of output classes See also [`Metalhead.mobilenetv3`](#). diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index afc446d14..1e6bb9fee 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -12,7 +12,7 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385 - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block - - `reduction_factor`: the factor by which the input feature maps are reduced before + - `reduction_factor`: the factor by which the input feature maps are reduced before the first convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. @@ -117,9 +117,9 @@ end # Shortcut configurations for the ResNet models const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity), - :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv), - :D => (downsample_pool, downsample_identity)) + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) # Stride for each block in the ResNet model function resnet_stride(stage_idx::Integer, block_idx::Integer) @@ -156,8 +156,8 @@ on how to use this function. convolution. This is an experimental variant from the `timm` library in Python that shows peformance improvements over the `:deep` stem in some cases. - - `inchannels`: The number of channels in the input. - - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + + - `inchannels`: number of input channels + - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. @@ -323,9 +323,9 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`" - @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`" - @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`" + @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0" + @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to 0.0" + @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 8e054da82..b5dae6663 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -8,6 +8,7 @@ Creates a bottleneck block as described in the Res2Net paper. ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block @@ -33,17 +34,15 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, for _ in 1:max(1, scale - 1)] reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) - tuplify = if is_first - x -> tuple(x...) - else - x -> tuple(x[1], tuple(x[2:end]...)) - end - layers = [conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm, bias = false)..., - chunk$(; size = width, dims = 3), tuplify, reslayer, - conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm, bias = false)..., - attn_fn(outplanes)] + tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) + layers = [ + conv_norm((1, 1), inplanes => width * scale, activation; + norm_layer, revnorm, bias = false)..., + chunk$(; size = width, dims = 3), tuplify, reslayer, + conv_norm((1, 1), width * scale => outplanes, activation; + norm_layer, revnorm, bias = false)..., + attn_fn(outplanes), + ] return Chain(filter(!=(identity), layers)...) end @@ -86,6 +85,7 @@ Creates a Res2Net model with the specified depth, scale, and base width. ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `scale`: the number of feature groups in the block. See the @@ -125,6 +125,7 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `scale`: the number of feature groups in the block. See the diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index cdccddd4b..f935c3b93 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -46,7 +46,7 @@ The number of channels in outer 1x1 convolutions is the same. - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the Wide ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: The number of input channels. - - `nclasses`: the number of output classes + - `nclasses`: The number of output classes Advanced users who want more configuration options will be better served by using [`resnet`](#). """ diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 664202d15..20fc912a2 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -32,7 +32,8 @@ end function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) - layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) + layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, + base_width) if pretrain loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d")) diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index da074e57d..ff39921b0 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -22,8 +22,6 @@ struct SEResNet end @functor SEResNet -(m::SEResNet)(x) = m.layers(x) - function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(RESNET_CONFIGS)) @@ -35,6 +33,8 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = return SEResNet(layers) end +(m::SEResNet)(x) = m.layers(x) + backbone(m::SEResNet) = m.layers[1] classifier(m::SEResNet) = m.layers[2] @@ -65,12 +65,12 @@ struct SEResNeXt end @functor SEResNeXt -(m::SEResNeXt)(x) = m.layers(x) - function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, - base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) + base_width::Integer = 4, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) - layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, + layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, + base_width, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) @@ -78,5 +78,7 @@ function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer return SEResNeXt(layers) end +(m::SEResNeXt)(x) = m.layers(x) + backbone(m::SEResNeXt) = m.layers[1] classifier(m::SEResNeXt) = m.layers[2] diff --git a/src/layers/conv.jl b/src/layers/conv.jl index cdbfc472c..7c7fe20af 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -122,8 +122,10 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, - specify the keyword argument `expansion`, which calculates + - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, + specify the keyword argument `expansion`, which calculates the number of feature + maps in the hidden layer from the number of input feature maps as: + `hidden_planes = inplanes * expansion` - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index cb9b8378c..abdab4b44 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -11,7 +11,7 @@ patches. # Arguments - `imsize`: the size of the input image - - `inchannels`: the number of channels in the input. + - `inchannels`: number of input channels - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding - `norm_layer`: the normalization layer - by default the identity function but otherwise takes a From 70841f628374d67e9c5428e5a7a50d7663d68745 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 08:57:04 +0530 Subject: [PATCH 19/34] Add tests --- .github/workflows/CI.yml | 3 ++- test/convnets.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5e14d6c49..37cda3263 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,7 +28,8 @@ jobs: suite: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - - '["EfficientNet"]' + - '"EfficientNet"' + - '"EfficientNetv2"' - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' - '[r"Res2Net", r"Res2NeXt"]' diff --git a/test/convnets.jl b/test/convnets.jl index 6d7dab496..31d68d6d1 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -177,6 +177,20 @@ end end end +@testset "EfficientNetv2" begin + @testset for config in [:small, :medium, :large, :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end +end + @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(x_224)) == (1000, 1) From 245eda062e5497a3652ce79cdebc1b0b0fc4d3af Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 16:12:06 +0530 Subject: [PATCH 20/34] Fix Inception bug, and other misc. cleanup --- src/convnets/densenet.jl | 4 +- src/convnets/efficientnet/efficientnet.jl | 10 +- src/convnets/efficientnet/efficientnetv2.jl | 83 +++++++------- src/convnets/inception/inceptionresnetv2.jl | 8 +- src/convnets/inception/inceptionv3.jl | 82 ++++++------- src/convnets/inception/inceptionv4.jl | 120 +++++++++----------- src/convnets/inception/xception.jl | 8 +- src/convnets/mobilenet/mobilenetv1.jl | 12 +- src/convnets/mobilenet/mobilenetv2.jl | 18 +-- src/convnets/mobilenet/mobilenetv3.jl | 10 +- src/layers/Layers.jl | 4 +- src/layers/conv.jl | 63 ++++++---- 12 files changed, 211 insertions(+), 211 deletions(-) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index badb61a9e..75e1ffde1 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -10,8 +10,8 @@ Create a Densenet bottleneck layer - `outplanes`: number of output feature maps on bottleneck branch (and scaling factor for inner feature maps; see ref) """ -function dense_bottleneck(inplanes::Integer, outplanes::Integer) - inner_channels = 4 * outplanes +function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) + inner_channels = expansion * outplanes return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl index 518ff15f2..677c57529 100644 --- a/src/convnets/efficientnet/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -17,10 +17,9 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). + `e`: expansion ratio + `i`: block input channels (will be scaled by global width scaling) + `o`: block output channels (will be scaled by global width scaling) + - `max_width`: The maximum number of feature maps in any layer of the network - `inchannels`: number of input channels - `nclasses`: number of output classes - - `max_width`: maximum number of output channels before the fully connected - classification blocks """ function efficientnet(scalings::NTuple{2, Real}, block_configs::AbstractVector{NTuple{6, Int}}; @@ -52,12 +51,7 @@ function efficientnet(scalings::NTuple{2, Real}, return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) end -# n: # of block repetitions -# k: kernel size k x k -# s: stride -# e: expantion ratio -# i: block input channels -# o: block output channels +# block configs for EfficientNet const EFFICIENTNET_BLOCK_CONFIGS = [ # (n, k, s, e, i, o) (1, 3, 1, 1, 32, 16), diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl index 40265bd2c..07f827095 100644 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -1,5 +1,5 @@ """ - efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, + efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, width_mult::Real = 1.0, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -14,16 +14,15 @@ Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + `c`: output channels of the block (will be scaled by width_mult) + `n`: number of block repetitions + `s`: kernel stride in the block except the first block of each stage - + `se`: whether to use a `squeeze_excite` layer in the block or not + + `r`: reduction factor of the squeeze-excite layer - - `max_width`: maximum number of output channels before the fully connected - classification blocks + - `max_width`: The maximum number of feature maps in any layer of the network - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper) - `inchannels`: number of input channels - `nclasses`: number of output classes """ -function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integer = 1792, +function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, width_mult::Real = 1.0, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer @@ -33,13 +32,12 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, bias = false)) # building inverted residual blocks - for (t, c, n, s, se) in config + for (t, c, n, s, reduction) in config outplanes = _round_channels(c * width_mult, 8) for i in 1:n push!(layers, invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, - stride = i == 1 ? s : 1, - reduction = se == 1 ? 4 : nothing)) + stride = i == 1 ? s : 1, reduction)) inplanes = outplanes end end @@ -50,43 +48,38 @@ function efficientnetv2(config::AbstractVector{NTuple{5, Int}}; max_width::Integ return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) end -# config dict of inverted residual blocks for EfficientNetv2 -# t: expansion factor of the block -# c: output channels of the block (will be scaled by width_mult) -# n: number of block repetitions -# s: kernel stride in the block except the first block of each stage -# se: whether to use a `squeeze_excite` layer in the block or not -const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, SE - (1, 24, 2, 1, 0), - (4, 48, 4, 2, 0), - (4, 64, 4, 2, 0), - (4, 128, 6, 2, 1), - (6, 160, 9, 1, 1), - (6, 256, 15, 2, 1)], - :medium => [# t, c, n, s, SE - (1, 24, 3, 1, 0), - (4, 48, 5, 2, 0), - (4, 80, 5, 2, 0), - (4, 160, 7, 2, 1), - (6, 176, 14, 1, 1), - (6, 304, 18, 2, 1), - (6, 512, 5, 1, 1)], - :large => [# t, c, n, s, SE - (1, 32, 4, 1, 0), - (4, 64, 8, 2, 0), - (4, 96, 8, 2, 0), - (4, 192, 16, 2, 1), - (6, 256, 24, 1, 1), - (6, 512, 32, 2, 1), - (6, 640, 8, 1, 1)], - :xlarge => [# t, c, n, s, SE - (1, 32, 4, 1, 0), - (4, 64, 8, 2, 0), - (4, 96, 8, 2, 0), - (4, 192, 16, 2, 1), - (6, 256, 24, 1, 1), - (6, 512, 32, 2, 1), - (6, 640, 8, 1, 1)]) +# block configs for EfficientNetv2 +const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, r + (1, 24, 2, 1, nothing), + (4, 48, 4, 2, nothing), + (4, 64, 4, 2, nothing), + (4, 128, 6, 2, 4), + (6, 160, 9, 1, 4), + (6, 256, 15, 2, 4)], + :medium => [# t, c, n, s, r + (1, 24, 3, 1, nothing), + (4, 48, 5, 2, nothing), + (4, 80, 5, 2, nothing), + (4, 160, 7, 2, 4), + (6, 176, 14, 1, 4), + (6, 304, 18, 2, 4), + (6, 512, 5, 1, 4)], + :large => [# t, c, n, s, r + (1, 32, 4, 1, nothing), + (4, 64, 8, 2, nothing), + (4, 96, 8, 2, nothing), + (4, 192, 16, 2, 4), + (6, 256, 24, 1, 4), + (6, 512, 32, 2, 4), + (6, 640, 8, 1, 4)], + :xlarge => [# t, c, n, s, r + (1, 32, 4, 1, nothing), + (4, 64, 8, 2, nothing), + (4, 96, 8, 2, nothing), + (4, 192, 16, 2, 4), + (6, 256, 24, 1, 4), + (6, 512, 32, 2, 4), + (6, 640, 8, 1, 4)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index 54bc59952..a99973d09 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -34,8 +34,8 @@ end function block17(scale = 1.0f0) branch1 = Chain(conv_norm((1, 1), 1088, 192)...) branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((1, 7), 128, 160; pad = (0, 3))..., - conv_norm((7, 1), 160, 192; pad = (3, 0))...) + conv_norm((7, 1), 128, 160; pad = (3, 0))..., + conv_norm((1, 7), 160, 192; pad = (0, 3))...) branch3 = Chain(conv_norm((1, 1), 384, 1088)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation = relu)), +) @@ -56,8 +56,8 @@ end function block8(scale = 1.0f0; activation = identity) branch1 = Chain(conv_norm((1, 1), 2080, 192)...) branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((1, 3), 192, 224; pad = (0, 1))..., - conv_norm((3, 1), 224, 256; pad = (1, 0))...) + conv_norm((3, 1), 192, 224; pad = (1, 0))..., + conv_norm((1, 3), 224, 256; pad = (0, 1))...) branch3 = Chain(conv_norm((1, 1), 448, 2080)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation)), +) diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl index 4f13d3695..32fbbede5 100644 --- a/src/convnets/inception/inceptionv3.jl +++ b/src/convnets/inception/inceptionv3.jl @@ -10,14 +10,14 @@ Create an Inception-v3 style-A module - `pool_proj`: the number of output feature maps for the pooling projection """ function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 64)...) + branch5x5 = Chain(basic_conv_bn((1, 1), inplanes, 48)..., + basic_conv_bn((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(basic_conv_bn((1, 1), inplanes, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, pool_proj)...) + basic_conv_bn((1, 1), inplanes, pool_proj)...) return Parallel(cat_channels, branch1x1, branch5x5, branch3x3, branch_pool) end @@ -33,10 +33,10 @@ Create an Inception-v3 style-B module - `inplanes`: number of input feature maps """ function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; stride = 2)...) + branch3x3_1 = Chain(basic_conv_bn((3, 3), inplanes, 384; stride = 2)...) + branch3x3_2 = Chain(basic_conv_bn((1, 1), inplanes, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3_1, branch3x3_2, branch_pool) @@ -55,17 +55,17 @@ Create an Inception-v3 style-C module - `n`: the "grid size" (kernel size) for the convolution layers """ function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) - branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 192)...) + branch7x7_1 = Chain(basic_conv_bn((1, 1), inplanes, inner_planes)..., + basic_conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + basic_conv_bn((1, n), inner_planes, 192; pad = (0, 3))...) + branch7x7_2 = Chain(basic_conv_bn((1, 1), inplanes, inner_planes)..., + basic_conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., + basic_conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + basic_conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., + basic_conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) + basic_conv_bn((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, branch7x7_1, branch7x7_2, branch_pool) end @@ -81,12 +81,12 @@ Create an Inception-v3 style-D module - `inplanes`: number of input feature maps """ function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((1, 7), 192, 192; pad = (0, 3))..., - conv_norm((7, 1), 192, 192; pad = (3, 0))..., - conv_norm((3, 3), 192, 192; stride = 2)...) + branch3x3 = Chain(basic_conv_bn((1, 1), inplanes, 192)..., + basic_conv_bn((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(basic_conv_bn((1, 1), inplanes, 192)..., + basic_conv_bn((7, 1), 192, 192; pad = (3, 0))..., + basic_conv_bn((1, 7), 192, 192; pad = (0, 3))..., + basic_conv_bn((3, 3), 192, 192; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3, branch7x7x3, branch_pool) @@ -103,16 +103,16 @@ Create an Inception-v3 style-E module - `inplanes`: number of input feature maps """ function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., - conv_norm((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 320)...) + branch3x3_1 = Chain(basic_conv_bn((1, 1), inplanes, 384)...) + branch3x3_1a = Chain(basic_conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_1b = Chain(basic_conv_bn((1, 3), 384, 384; pad = (0, 1))...) + branch3x3_2 = Chain(basic_conv_bn((1, 1), inplanes, 448)..., + basic_conv_bn((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(basic_conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_2b = Chain(basic_conv_bn((1, 3), 384, 384; pad = (0, 1))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) + basic_conv_bn((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, Chain(branch3x3_1, @@ -135,12 +135,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). """ function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_norm((1, 1), 64, 80)..., - conv_norm((3, 3), 80, 192)..., + basic_conv_bn((1, 1), 64, 80)..., + basic_conv_bn((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), inceptionv3_a(192, 32), inceptionv3_a(256, 64), diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl index 7f027da6e..b43f6bc1d 100644 --- a/src/convnets/inception/inceptionv4.jl +++ b/src/convnets/inception/inceptionv4.jl @@ -1,83 +1,86 @@ function mixed_3a() return Parallel(cat_channels, MaxPool((3, 3); stride = 2), - Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) + Chain(basic_conv_bn((3, 3), 64, 96; stride = 2)...)) end function mixed_4a() return Parallel(cat_channels, - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((3, 3), 64, 96)...), - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((1, 7), 64, 64; pad = (0, 3))..., - conv_norm((7, 1), 64, 64; pad = (3, 0))..., - conv_norm((3, 3), 64, 96)...)) + Chain(basic_conv_bn((1, 1), 160, 64)..., + basic_conv_bn((3, 3), 64, 96)...), + Chain(basic_conv_bn((1, 1), 160, 64)..., + basic_conv_bn((7, 1), 64, 64; pad = (3, 0))..., + basic_conv_bn((1, 7), 64, 64; pad = (0, 3))..., + basic_conv_bn((3, 3), 64, 96)...)) end function mixed_5a() return Parallel(cat_channels, - Chain(conv_norm((3, 3), 192, 192; stride = 2)...), + Chain(basic_conv_bn((3, 3), 192, 192; stride = 2)...), MaxPool((3, 3); stride = 2)) end function inceptionv4_a() - branch1 = Chain(conv_norm((1, 1), 384, 96)...) - branch2 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) + branch1 = Chain(basic_conv_bn((1, 1), 384, 96)...) + branch2 = Chain(basic_conv_bn((1, 1), 384, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)...) + branch3 = Chain(basic_conv_bn((1, 1), 384, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 384, 96)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function reduction_a() - branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 384, 192)..., - conv_norm((3, 3), 192, 224; pad = 1)..., - conv_norm((3, 3), 224, 256; stride = 2)...) + branch1 = Chain(basic_conv_bn((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 384, 192)..., + basic_conv_bn((3, 3), 192, 224; pad = 1)..., + basic_conv_bn((3, 3), 224, 256; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function inceptionv4_b() - branch1 = Chain(conv_norm((1, 1), 1024, 384)...) - branch2 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((1, 7), 192, 224; pad = (0, 3))..., - conv_norm((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((7, 1), 192, 192; pad = (0, 3))..., - conv_norm((1, 7), 192, 224; pad = (3, 0))..., - conv_norm((7, 1), 224, 224; pad = (0, 3))..., - conv_norm((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) + branch1 = Chain(basic_conv_bn((1, 1), 1024, 384)...) + branch2 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((7, 1), 192, 224; pad = (0, 3))..., + basic_conv_bn((1, 7), 224, 256; pad = (3, 0))...) + branch3 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((1, 7), 192, 192; pad = (3, 0))..., + basic_conv_bn((7, 1), 192, 224; pad = (0, 3))..., + basic_conv_bn((1, 7), 224, 224; pad = (3, 0))..., + basic_conv_bn((7, 1), 224, 256; pad = (0, 3))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 1024, 128)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function reduction_b() - branch1 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1024, 256)..., - conv_norm((1, 7), 256, 256; pad = (0, 3))..., - conv_norm((7, 1), 256, 320; pad = (3, 0))..., - conv_norm((3, 3), 320, 320; stride = 2)...) + branch1 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 1024, 256)..., + basic_conv_bn((7, 1), 256, 256; pad = (3, 0))..., + basic_conv_bn((1, 7), 256, 320; pad = (0, 3))..., + basic_conv_bn((3, 3), 320, 320; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function inceptionv4_c() - branch1 = Chain(conv_norm((1, 1), 1536, 256)...) - branch2 = Chain(conv_norm((1, 1), 1536, 384)..., + branch1 = Chain(basic_conv_bn((1, 1), 1536, 256)...) + branch2 = Chain(basic_conv_bn((1, 1), 1536, 384)..., Parallel(cat_channels, - Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_norm((1, 1), 1536, 384)..., - conv_norm((3, 1), 384, 448; pad = (1, 0))..., - conv_norm((1, 3), 448, 512; pad = (0, 1))..., + Chain(basic_conv_bn((3, 1), 384, 256; pad = (1, 0))...), + Chain(basic_conv_bn((1, 3), 384, 256; pad = (0, 1))...))) + branch3 = Chain(basic_conv_bn((1, 1), 1536, 384)..., + basic_conv_bn((1, 3), 384, 448; pad = (0, 1))..., + basic_conv_bn((3, 1), 448, 512; pad = (1, 0))..., Parallel(cat_channels, - Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) + Chain(basic_conv_bn((3, 1), 512, 256; pad = (1, 0))...), + Chain(basic_conv_bn((1, 3), 512, 256; pad = (0, 1))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 1536, 256)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end @@ -95,28 +98,15 @@ Create an Inceptionv4 model. """ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - mixed_3a(), - mixed_4a(), - mixed_5a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., + mixed_3a(), mixed_4a(), mixed_5a(), + [inceptionv4_a() for _ in 1:4]..., reduction_a(), # mixed_6a - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), + [inceptionv4_b() for _ in 1:7]..., reduction_b(), # mixed_7a - inceptionv4_c(), - inceptionv4_c(), - inceptionv4_c()) + [inceptionv4_c() for _ in 1:3]...) return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index d4751352c..14b5444d6 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -35,8 +35,8 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, - use_norm = (false, false))) + dwsep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, + use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -64,8 +64,8 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integ xception_block(256, 728, 2; stride = 2), [xception_block(728, 728, 3) for _ in 1:8]..., xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...) + dwsep_conv_bn((3, 3), 1024, 1536; pad = 1)..., + dwsep_conv_bn((3, 3), 1536, 2048; pad = 1)...) return Chain(backbone, create_classifier(2048, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index b390a3f55..db9dedbdb 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -22,16 +22,16 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] - for (dw, outch, stride, nrepeats) in config - outch = floor(Int, outch * width_mult) + for (dw, outchannels, stride, nrepeats) in config + outchannels = floor(Int, outchannels * width_mult) for _ in 1:nrepeats layer = dw ? - depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, + dwsep_conv_bn((3, 3), inchannels, outchannels, activation; + stride, pad = 1, bias = false) : + conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1, bias = false) append!(layers, layer) - inchannels = outch + inchannels = outchannels end end return Chain(Chain(layers...), create_classifier(inchannels, nclasses)) diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index fd5bc6691..39db8933b 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -1,7 +1,7 @@ """ mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv2 model. ([reference](https://arxiv.org/abs/1801.04381)). @@ -18,14 +18,15 @@ Create a MobileNetv2 model. - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper) - - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network + - `divisor`: The divisor used to round the number of feature maps in each block + - `dropout_rate`: rate of dropout in the classifier head + - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) - divisor = width_mult == 0.1 ? 4 : 8 + max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(32 * width_mult, divisor) layers = [] @@ -42,10 +43,9 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end end # building last layers - outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, divisor) : - max_width + outplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 82c5fb187..c3bca7a95 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -34,13 +34,13 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, bias = false)) explanes = 0 # building inverted residual blocks - for (k, t, c, r, a, s) in configs + for (k, t, c, reduction, activation, stride) in configs # inverted residual layers outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, - invertedresidual((k, k), inplanes, explanes, outplanes, a; - stride = s, reduction = r)) + invertedresidual((k, k), inplanes, explanes, outplanes, activation; + stride, reduction)) inplanes = outplanes end # building last layers @@ -56,7 +56,7 @@ end # Layer configurations for small and large models for MobileNetv3 const MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, SE, a, s + # k, t, c, r, a, s (3, 1, 16, 4, relu, 2), (3, 4.5, 24, nothing, relu, 2), (3, 3.67, 24, nothing, relu, 1), @@ -70,7 +70,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => [ (5, 6, 96, 4, hardswish, 1), ], :large => [ - # k, t, c, SE, a, s + # k, t, c, r, a, s (3, 1, 16, nothing, relu, 1), (3, 4, 24, nothing, relu, 2), (3, 3, 24, nothing, relu, 1), diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 04be476ff..ec48ab772 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -11,13 +11,15 @@ using MLUtils using PartialFunctions using Random +import Flux.testmode! + include("../utilities.jl") include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, depthwise_sep_conv_norm, invertedresidual +export conv_norm, basic_conv_bn, dwsep_conv_bn, invertedresidual include("drop.jl") export DropBlock, DropPath diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7c7fe20af..fcccae691 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -28,22 +28,34 @@ Create a convolution + batch normalization pair with activation. - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, + momentum::Union{Nothing, Float32} = nothing, preact::Bool = false, use_norm::Bool = true, kwargs...) + # handle momentum for BatchNorm + if norm_layer == BatchNorm + momentum = isnothing(momentum) ? 0.1f0 : momentum + norm_layer = (args...; kargs...) -> BatchNorm(args...; momentum, kargs...) + elseif norm_layer != BatchNorm && !isnothing(momentum) + error("momentum is only supported for BatchNorm") + end + # no normalization layer if !use_norm - if (preact || revnorm) + if preact || revnorm throw(ArgumentError("`preact` only supported with `use_norm = true`")) else + # early return if no norm layer is required return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] end end + # channels for norm layer and activation functions for both conv and norm if revnorm activations = (conv = activation, bn = identity) - bnplanes = inplanes + normplanes = inplanes else activations = (conv = identity, bn = activation) - bnplanes = outplanes + normplanes = outplanes end + # handle pre-activation if preact if revnorm throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time")) @@ -51,8 +63,9 @@ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activatio activations = (conv = activation, bn = identity) end end + # layers layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), - norm_layer(bnplanes, activations.bn)] + norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end @@ -62,8 +75,14 @@ function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = ide return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end +# conv + bn layer combination as used by the inception model family +function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) + return conv_norm(kernel_size, inplanes, outplanes, activation; eps = 1.0f-3, + bias = false, kwargs...) +end + """ - depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, + dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), @@ -95,15 +114,15 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), kwargs...) +function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, momentum::Float32 = 0.1f0, + revnorm::Bool = false, stride::Integer = 1, + use_norm::NTuple{2, Bool} = (true, true), kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; - norm_layer, revnorm, use_norm = use_norm[1], stride, + revnorm, use_norm = use_norm[1], stride, groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, - use_norm = use_norm[2])) + conv_norm((1, 1), inplanes, outplanes, activation; + revnorm, use_norm = use_norm[2])) end """ @@ -134,7 +153,8 @@ Create a basic inverted residual block for MobileNet variants """ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing) + reduction::Union{Nothing, Integer} = nothing, + momentum::Float32 = 0.1f0) @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 conv1 = inplanes == hidden_planes ? (identity,) : @@ -142,17 +162,18 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, norm_layer = BatchNorm) - invres = Chain(conv1..., - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad, groups = hidden_planes)..., - selayer, - conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) - return stride == 1 && inplanes == outplanes ? SkipConnection(invres, +) : invres + invres = [conv1..., + conv_norm(kernel_size, hidden_planes, hidden_planes, activation; + bias = false, stride, pad, groups = hidden_planes, momentum)..., + selayer, + conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...] + layers = Chain(filter(!=(identity), invres)...) + return stride == 1 && inplanes == outplanes ? SkipConnection(layers, +) : layers end function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) - return invertedresidual(kernel_size, inplanes, floor(Int, inplanes * expansion), + return invertedresidual(kernel_size, inplanes, round(Int, inplanes * expansion), outplanes, activation; stride, reduction) end From 1c65159efa14438b434696bd1034b982acdb600f Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 11 Aug 2022 23:54:04 +0530 Subject: [PATCH 21/34] Refactor: `mbconv` instead of `invertedresidual` Also fix bug in EfficientNet models --- src/convnets/efficientnet/efficientnet.jl | 18 +++-- src/convnets/efficientnet/efficientnetv2.jl | 79 +++++++++--------- src/convnets/inception/inceptionresnetv2.jl | 82 +++++++++---------- src/convnets/mobilenet/mobilenetv2.jl | 4 +- src/convnets/mobilenet/mobilenetv3.jl | 10 +-- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 90 ++++++++++++--------- src/layers/selayers.jl | 36 ++++++--- test/convnets.jl | 2 +- 9 files changed, 178 insertions(+), 145 deletions(-) diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl index 677c57529..4f06d81b7 100644 --- a/src/convnets/efficientnet/efficientnet.jl +++ b/src/convnets/efficientnet/efficientnet.jl @@ -17,35 +17,37 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). + `e`: expansion ratio + `i`: block input channels (will be scaled by global width scaling) + `o`: block output channels (will be scaled by global width scaling) - - `max_width`: The maximum number of feature maps in any layer of the network - `inchannels`: number of input channels - `nclasses`: number of output classes """ function efficientnet(scalings::NTuple{2, Real}, block_configs::AbstractVector{NTuple{6, Int}}; - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) + inchannels::Integer = 3, nclasses::Integer = 1000) + # building first layer wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) outplanes = _round_channels(scalew(32), 8) stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, pad = SamePad()) + # building inverted residual blocks blocks = [] for (n, k, s, e, i, o) in block_configs inchannels = _round_channels(scalew(i), 8) + explanes = _round_channels(inchannels * e, 8) outplanes = _round_channels(scalew(o), 8) repeats = scaled(n) push!(blocks, - invertedresidual((k, k), inchannels, outplanes, swish; expansion = e, - stride = s, reduction = 4)) + mbconv((k, k), inchannels, explanes, outplanes, swish; + stride = s, reduction = 4)) for _ in 1:(repeats - 1) push!(blocks, - invertedresidual((k, k), outplanes, outplanes, swish; expansion = e, - stride = 1, reduction = 4)) + mbconv((k, k), outplanes, explanes, outplanes, swish; + stride = 1, reduction = 4)) end end - headplanes = _round_channels(max_width, 8) + # building last layers + headplanes = outplanes * 4 append!(blocks, conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl index 07f827095..afcba5dab 100644 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ b/src/convnets/efficientnet/efficientnetv2.jl @@ -32,12 +32,19 @@ function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 17 conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, bias = false)) # building inverted residual blocks - for (t, c, n, s, reduction) in config - outplanes = _round_channels(c * width_mult, 8) + for (t, inplanes, outplanes, n, s, reduction) in config + explanes = _round_channels(inplanes * t, 8) for i in 1:n - push!(layers, - invertedresidual((3, 3), inplanes, outplanes, swish; expansion = t, - stride = i == 1 ? s : 1, reduction)) + stride = i == 1 ? s : 1 + if isnothing(reduction) + push!(layers, + fused_mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) + else + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + push!(layers, + mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) + end inplanes = outplanes end end @@ -49,37 +56,37 @@ function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 17 end # block configs for EfficientNetv2 -const EFFNETV2_CONFIGS = Dict(:small => [# t, c, n, s, r - (1, 24, 2, 1, nothing), - (4, 48, 4, 2, nothing), - (4, 64, 4, 2, nothing), - (4, 128, 6, 2, 4), - (6, 160, 9, 1, 4), - (6, 256, 15, 2, 4)], - :medium => [# t, c, n, s, r - (1, 24, 3, 1, nothing), - (4, 48, 5, 2, nothing), - (4, 80, 5, 2, nothing), - (4, 160, 7, 2, 4), - (6, 176, 14, 1, 4), - (6, 304, 18, 2, 4), - (6, 512, 5, 1, 4)], - :large => [# t, c, n, s, r - (1, 32, 4, 1, nothing), - (4, 64, 8, 2, nothing), - (4, 96, 8, 2, nothing), - (4, 192, 16, 2, 4), - (6, 256, 24, 1, 4), - (6, 512, 32, 2, 4), - (6, 640, 8, 1, 4)], - :xlarge => [# t, c, n, s, r - (1, 32, 4, 1, nothing), - (4, 64, 8, 2, nothing), - (4, 96, 8, 2, nothing), - (4, 192, 16, 2, 4), - (6, 256, 24, 1, 4), - (6, 512, 32, 2, 4), - (6, 640, 8, 1, 4)]) +const EFFNETV2_CONFIGS = Dict(:small => [ + (1, 24, 24, 2, 1, nothing), + (4, 24, 48, 4, 2, nothing), + (4, 48, 64, 4, 2, nothing), + (4, 64, 128, 6, 2, 4), + (6, 128, 160, 9, 1, 4), + (6, 160, 256, 15, 2, 4)], + :medium => [ + (1, 24, 24, 3, 1, nothing), + (4, 24, 48, 5, 2, nothing), + (4, 48, 80, 5, 2, nothing), + (4, 80, 160, 7, 2, 4), + (6, 160, 176, 14, 1, 4), + (6, 176, 304, 18, 2, 4), + (6, 304, 512, 5, 1, 4)], + :large => [ + (1, 32, 32, 4, 1, nothing), + (4, 32, 64, 7, 2, nothing), + (4, 64, 96, 7, 2, nothing), + (4, 96, 192, 10, 2, 4), + (6, 192, 224, 19, 1, 4), + (6, 224, 384, 25, 2, 4), + (6, 384, 640, 7, 1, 4)], + :xlarge => [ + (1, 32, 32, 4, 1, nothing), + (4, 32, 64, 8, 2, nothing), + (4, 64, 96, 8, 2, nothing), + (4, 96, 192, 16, 2, 4), + (6, 192, 256, 24, 1, 4), + (6, 256, 512, 32, 2, 4), + (6, 512, 640, 8, 1, 4)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index a99973d09..7f462c0cf 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -1,64 +1,64 @@ function mixed_5b() - branch1 = Chain(conv_norm((1, 1), 192, 96)...) - branch2 = Chain(conv_norm((1, 1), 192, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_norm((1, 1), 192, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) + branch1 = Chain(basic_conv_bn((1, 1), 192, 96)...) + branch2 = Chain(basic_conv_bn((1, 1), 192, 48)..., + basic_conv_bn((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(basic_conv_bn((1, 1), 192, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), 192, 64)...) + basic_conv_bn((1, 1), 192, 64)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block35(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 320, 32)...) - branch2 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 48; pad = 1)..., - conv_norm((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_norm((1, 1), 128, 320)...) + branch1 = Chain(basic_conv_bn((1, 1), 320, 32)...) + branch2 = Chain(basic_conv_bn((1, 1), 320, 32)..., + basic_conv_bn((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(basic_conv_bn((1, 1), 320, 32)..., + basic_conv_bn((3, 3), 32, 48; pad = 1)..., + basic_conv_bn((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(basic_conv_bn((1, 1), 128, 320)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), branch4, inputscale(scale; activation = relu)), +) end function mixed_6a() - branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 320, 256)..., - conv_norm((3, 3), 256, 256; pad = 1)..., - conv_norm((3, 3), 256, 384; stride = 2)...) + branch1 = Chain(basic_conv_bn((3, 3), 320, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 320, 256)..., + basic_conv_bn((3, 3), 256, 256; pad = 1)..., + basic_conv_bn((3, 3), 256, 384; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function block17(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 1088, 192)...) - branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((7, 1), 128, 160; pad = (3, 0))..., - conv_norm((1, 7), 160, 192; pad = (0, 3))...) - branch3 = Chain(conv_norm((1, 1), 384, 1088)...) + branch1 = Chain(basic_conv_bn((1, 1), 1088, 192)...) + branch2 = Chain(basic_conv_bn((1, 1), 1088, 128)..., + basic_conv_bn((7, 1), 128, 160; pad = (3, 0))..., + basic_conv_bn((1, 7), 160, 192; pad = (0, 3))...) + branch3 = Chain(basic_conv_bn((1, 1), 384, 1088)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation = relu)), +) end function mixed_7a() - branch1 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; pad = 1)..., - conv_norm((3, 3), 288, 320; stride = 2)...) + branch1 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 288; pad = 1)..., + basic_conv_bn((3, 3), 288, 320; stride = 2)...) branch4 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_norm((1, 1), 2080, 192)...) - branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((3, 1), 192, 224; pad = (1, 0))..., - conv_norm((1, 3), 224, 256; pad = (0, 1))...) - branch3 = Chain(conv_norm((1, 1), 448, 2080)...) + branch1 = Chain(basic_conv_bn((1, 1), 2080, 192)...) + branch2 = Chain(basic_conv_bn((1, 1), 2080, 192)..., + basic_conv_bn((3, 1), 192, 224; pad = (1, 0))..., + basic_conv_bn((1, 3), 224, 256; pad = (0, 1))...) + branch3 = Chain(basic_conv_bn((1, 1), 448, 2080)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation)), +) end @@ -77,12 +77,12 @@ Creates an InceptionResNetv2 model. """ function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_norm((3, 3), 64, 80)..., - conv_norm((3, 3), 80, 192)..., + basic_conv_bn((3, 3), 64, 80)..., + basic_conv_bn((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), mixed_5b(), [block35(0.17f0) for _ in 1:10]..., @@ -91,7 +91,7 @@ function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, mixed_7a(), [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), - conv_norm((1, 1), 2080, 1536)...) + basic_conv_bn((1, 1), 2080, 1536)...) return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 39db8933b..7bc87bcb9 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -37,8 +37,8 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, outplanes = _round_channels(c * width_mult, divisor) for i in 1:n push!(layers, - invertedresidual((3, 3), inplanes, outplanes, a; expansion = t, - stride = i == 1 ? s : 1)) + mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a; + stride = i == 1 ? s : 1)) inplanes = outplanes end end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index c3bca7a95..68fe2f03b 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -24,8 +24,8 @@ Create a MobileNetv3 model. - `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] @@ -39,8 +39,8 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, - invertedresidual((k, k), inplanes, explanes, outplanes, activation; - stride, reduction)) + mbconv((k, k), inplanes, explanes, outplanes, activation; + stride, reduction)) inplanes = outplanes end # building last layers @@ -49,7 +49,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), - Dropout(0.2), + Dropout(dropout_rate), Dense(headplanes, nclasses)) return Chain(Chain(layers...), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index ec48ab772..72ace2c2c 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -19,7 +19,7 @@ include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, basic_conv_bn, dwsep_conv_bn, invertedresidual +export conv_norm, basic_conv_bn, dwsep_conv_bn, mbconv, fused_mbconv include("drop.jl") export DropBlock, DropPath diff --git a/src/layers/conv.jl b/src/layers/conv.jl index fcccae691..087082c8b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -29,15 +29,7 @@ Create a convolution + batch normalization pair with activation. """ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, - momentum::Union{Nothing, Float32} = nothing, preact::Bool = false, - use_norm::Bool = true, kwargs...) - # handle momentum for BatchNorm - if norm_layer == BatchNorm - momentum = isnothing(momentum) ? 0.1f0 : momentum - norm_layer = (args...; kargs...) -> BatchNorm(args...; momentum, kargs...) - elseif norm_layer != BatchNorm && !isnothing(momentum) - error("momentum is only supported for BatchNorm") - end + preact::Bool = false, use_norm::Bool = true, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -75,10 +67,11 @@ function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = ide return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end -# conv + bn layer combination as used by the inception model family +# conv + bn layer combination as used by the inception model family matching +# the default values used in TensorFlow function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) - return conv_norm(kernel_size, inplanes, outplanes, activation; eps = 1.0f-3, - bias = false, kwargs...) + return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, + eps = 1.0f-3, bias = false, kwargs...) end """ @@ -115,22 +108,22 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, momentum::Float32 = 0.1f0, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), kwargs...) - return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, revnorm, use_norm = use_norm[1], stride, groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; + conv_norm((1, 1), inplanes, outplanes, activation; eps, revnorm, use_norm = use_norm[2])) end """ - invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, + mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, outplanes::Integer, activation = relu; stride::Integer, reduction::Union{Nothing, Integer} = nothing) - invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, + mbconv(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) @@ -151,29 +144,48 @@ Create a basic inverted residual block for MobileNet variants - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)). """ -function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing, - momentum::Float32 = 0.1f0) +function mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" - pad = @. (kernel_size - 1) ÷ 2 - conv1 = inplanes == hidden_planes ? (identity,) : - conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false) - selayer = isnothing(reduction) ? identity : - squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, - norm_layer = BatchNorm) - invres = [conv1..., - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad, groups = hidden_planes, momentum)..., - selayer, - conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...] - layers = Chain(filter(!=(identity), invres)...) - return stride == 1 && inplanes == outplanes ? SkipConnection(layers, +) : layers + layers = [] + # expand + if inplanes != hidden_planes + append!(layers, + conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false, + norm_layer)) + end + # squeeze-excite layer + if !isnothing(reduction) + append!(layers, + squeeze_excite(hidden_planes, inplanes ÷ reduction; activation, + gate_activation = hardσ)) + end + # depthwise + append!(layers, + conv_norm(kernel_size, hidden_planes, hidden_planes, activation; bias = false, + norm_layer, stride, pad = SamePad(), groups = hidden_planes)) + # project + append!(layers, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)) + return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : + Chain(layers...) end -function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion::Real, - reduction::Union{Nothing, Integer} = nothing) - return invertedresidual(kernel_size, inplanes, round(Int, inplanes * expansion), - outplanes, activation; stride, reduction) +function fused_mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, + activation = relu; stride::Integer, norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2" + layers = [] + if explanes != inplanes + # fused expand + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride)) + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) + else + append!(layers, + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) + end + return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : + Chain(layers...) end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 0756225ba..db034d1af 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -1,34 +1,46 @@ """ + squeeze_excite(inplanes::Integer, squeeze_planes::Integer; + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) + squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, - activation = relu, gate_activation = sigmoid, norm_layer = identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + activation = relu, gate_activation = sigmoid, norm_layer = identity) -Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. +Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets. # Arguments - `inplanes`: The number of input feature maps - - `reduction`: The reduction factor for the number of hidden feature maps - - `rd_divisor`: The divisor for the number of hidden feature maps. + - `squeeze_planes`: The number of feature maps in the intermediate layers. Alternatively, + specify the keyword arguments `reduction` and `rd_divisior`, which determine the number + of feature maps in the intermediate layers from the number of input feature maps as: + `squeeze_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)`. + (See [`_round_channels`](#) for details. The default values are `reduction = 16` and + `rd_divisor = 8`.) - `activation`: The activation function for the first convolution layer - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer """ -function squeeze_excite(inplanes::Integer; reduction::Integer = 16, - rd_divisor::Integer = 8, activation = relu, - gate_activation = sigmoid, norm_layer = planes -> identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) +function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) layers = [AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes => rd_planes), - norm_layer(rd_planes), + Conv((1, 1), inplanes => squeeze_planes), + norm_layer(squeeze_planes), activation, - Conv((1, 1), rd_planes => inplanes), + Conv((1, 1), squeeze_planes => inplanes), norm_layer(inplanes), gate_activation] return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end +function squeeze_excite(inplanes::Integer; reduction::Integer = 16, rd_divisor::Integer = 8, + kwargs...) + return squeeze_excite(inplanes, _round_channels(inplanes ÷ reduction, rd_divisor, 0); + kwargs...) +end + """ effective_squeeze_excite(inplanes, gate_activation = sigmoid) diff --git a/test/convnets.jl b/test/convnets.jl index 31d68d6d1..e087ceb0e 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -161,7 +161,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] x = rand(Float32, r, r, 3, 1) From 744f214a707e12cedf319c37953d339b68ee3aea Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 12 Aug 2022 13:51:02 +0530 Subject: [PATCH 22/34] Refactor EfficientNets --- src/Metalhead.jl | 21 +-- src/convnets/convmixer.jl | 4 +- src/convnets/densenet.jl | 8 +- src/convnets/efficientnet/efficientnet.jl | 113 ---------------- src/convnets/efficientnet/efficientnetv2.jl | 124 ------------------ src/convnets/efficientnets/core.jl | 78 +++++++++++ src/convnets/efficientnets/efficientnet.jl | 59 +++++++++ src/convnets/efficientnets/efficientnetv2.jl | 67 ++++++++++ .../{inception => inceptions}/googlenet.jl | 0 .../inceptionresnetv2.jl | 0 .../{inception => inceptions}/inceptionv3.jl | 0 .../{inception => inceptions}/inceptionv4.jl | 0 .../{inception => inceptions}/xception.jl | 10 +- .../{mobilenet => mobilenets}/mobilenetv1.jl | 5 +- .../{mobilenet => mobilenets}/mobilenetv2.jl | 4 +- .../{mobilenet => mobilenets}/mobilenetv3.jl | 5 +- src/convnets/resnets/core.jl | 24 ++-- src/convnets/resnets/res2net.jl | 6 +- src/convnets/vgg.jl | 2 +- src/layers/conv.jl | 104 ++++++++------- src/layers/drop.jl | 1 + test/convnets.jl | 4 +- 22 files changed, 301 insertions(+), 338 deletions(-) delete mode 100644 src/convnets/efficientnet/efficientnet.jl delete mode 100644 src/convnets/efficientnet/efficientnetv2.jl create mode 100644 src/convnets/efficientnets/core.jl create mode 100644 src/convnets/efficientnets/efficientnet.jl create mode 100644 src/convnets/efficientnets/efficientnetv2.jl rename src/convnets/{inception => inceptions}/googlenet.jl (100%) rename src/convnets/{inception => inceptions}/inceptionresnetv2.jl (100%) rename src/convnets/{inception => inceptions}/inceptionv3.jl (100%) rename src/convnets/{inception => inceptions}/inceptionv4.jl (100%) rename src/convnets/{inception => inceptions}/xception.jl (92%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv1.jl (95%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv2.jl (97%) rename src/convnets/{mobilenet => mobilenets}/mobilenetv3.jl (98%) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index cad13afc9..6b0179f45 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -28,18 +28,19 @@ include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") include("convnets/resnets/res2net.jl") ## Inceptions -include("convnets/inception/googlenet.jl") -include("convnets/inception/inceptionv3.jl") -include("convnets/inception/inceptionv4.jl") -include("convnets/inception/inceptionresnetv2.jl") -include("convnets/inception/xception.jl") +include("convnets/inceptions/googlenet.jl") +include("convnets/inceptions/inceptionv3.jl") +include("convnets/inceptions/inceptionv4.jl") +include("convnets/inceptions/inceptionresnetv2.jl") +include("convnets/inceptions/xception.jl") ## EfficientNets -include("convnets/efficientnet/efficientnet.jl") -include("convnets/efficientnet/efficientnetv2.jl") +include("convnets/efficientnets/core.jl") +include("convnets/efficientnets/efficientnet.jl") +include("convnets/efficientnets/efficientnetv2.jl") ## MobileNets -include("convnets/mobilenet/mobilenetv1.jl") -include("convnets/mobilenet/mobilenetv2.jl") -include("convnets/mobilenet/mobilenetv3.jl") +include("convnets/mobilenets/mobilenetv1.jl") +include("convnets/mobilenets/mobilenetv2.jl") +include("convnets/mobilenets/mobilenetv3.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index c7dd058ff..1ca8487a9 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -1,5 +1,5 @@ """ - convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), + convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -16,7 +16,7 @@ Creates a ConvMixer model. - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ -function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), +function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 75e1ffde1..ca81b78ea 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -12,10 +12,10 @@ Create a Densenet bottleneck layer """ function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) inner_channels = expansion * outplanes - return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, + return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, - bias = false, revnorm = true)...), + revnorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes::Integer, outplanes::Integer) - return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; revnorm = true)..., MeanPool((2, 2))) end @@ -72,7 +72,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: nclasses::Integer = 1000) layers = [] append!(layers, - conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false)) + conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl deleted file mode 100644 index 4f06d81b7..000000000 --- a/src/convnets/efficientnet/efficientnet.jl +++ /dev/null @@ -1,113 +0,0 @@ -""" - efficientnet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - -# Arguments - - - `scalings`: global width and depth scaling (given as a tuple) - - - `block_configs`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `n`: number of block repetitions (will be scaled by global depth scaling) - + `k`: kernel size - + `s`: kernel stride - + `e`: expansion ratio - + `i`: block input channels (will be scaled by global width scaling) - + `o`: block output channels (will be scaled by global width scaling) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -function efficientnet(scalings::NTuple{2, Real}, - block_configs::AbstractVector{NTuple{6, Int}}; - inchannels::Integer = 3, nclasses::Integer = 1000) - # building first layer - wscale, dscale = scalings - scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) - scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - outplanes = _round_channels(scalew(32), 8) - stem = conv_norm((3, 3), inchannels, outplanes, swish; bias = false, stride = 2, - pad = SamePad()) - # building inverted residual blocks - blocks = [] - for (n, k, s, e, i, o) in block_configs - inchannels = _round_channels(scalew(i), 8) - explanes = _round_channels(inchannels * e, 8) - outplanes = _round_channels(scalew(o), 8) - repeats = scaled(n) - push!(blocks, - mbconv((k, k), inchannels, explanes, outplanes, swish; - stride = s, reduction = 4)) - for _ in 1:(repeats - 1) - push!(blocks, - mbconv((k, k), outplanes, explanes, outplanes, swish; - stride = 1, reduction = 4)) - end - end - # building last layers - headplanes = outplanes * 4 - append!(blocks, - conv_norm((1, 1), outplanes, headplanes, swish; bias = false, pad = SamePad())) - return Chain(Chain(stem..., blocks...), create_classifier(headplanes, nclasses)) -end - -# block configs for EfficientNet -const EFFICIENTNET_BLOCK_CONFIGS = [ - # (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), - (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320), -] - -# w: width scaling -# d: depth scaling -# r: image resolution -# Data is organised as (r, (w, d)) -const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6))) - -""" - EfficientNet(config::Symbol; pretrain::Bool = false) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). -See also [`efficientnet`](#). - -# Arguments - - - `config`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet -""" -struct EfficientNet - layers::Any -end -@functor EfficientNet - -function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) - _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS; - inchannels, nclasses) - if pretrain - loadpretrain!(model, string("efficientnet-", config)) - end - return model -end - -(m::EfficientNet)(x) = m.layers(x) - -backbone(m::EfficientNet) = m.layers[1] -classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnet/efficientnetv2.jl b/src/convnets/efficientnet/efficientnetv2.jl deleted file mode 100644 index afcba5dab..000000000 --- a/src/convnets/efficientnet/efficientnetv2.jl +++ /dev/null @@ -1,124 +0,0 @@ -""" - efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, - width_mult::Real = 1.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - -Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - -# Arguments - - - `config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `t`: expansion factor of the block - + `c`: output channels of the block (will be scaled by width_mult) - + `n`: number of block repetitions - + `s`: kernel stride in the block except the first block of each stage - + `r`: reduction factor of the squeeze-excite layer - - - `max_width`: The maximum number of feature maps in any layer of the network - - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -function efficientnetv2(config::AbstractVector{<:Tuple}; max_width::Integer = 1792, - width_mult::Real = 1.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - # building first layer - inplanes = _round_channels(24 * width_mult, 8) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, swish; pad = 1, stride = 2, - bias = false)) - # building inverted residual blocks - for (t, inplanes, outplanes, n, s, reduction) in config - explanes = _round_channels(inplanes * t, 8) - for i in 1:n - stride = i == 1 ? s : 1 - if isnothing(reduction) - push!(layers, - fused_mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) - else - inplanes = _round_channels(inplanes * width_mult, 8) - outplanes = _round_channels(outplanes * width_mult, 8) - push!(layers, - mbconv((3, 3), inplanes, explanes, outplanes, swish; stride)) - end - inplanes = outplanes - end - end - # building last layers - outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, 8) : - max_width - append!(layers, conv_norm((1, 1), inplanes, outplanes, swish; bias = false)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) -end - -# block configs for EfficientNetv2 -const EFFNETV2_CONFIGS = Dict(:small => [ - (1, 24, 24, 2, 1, nothing), - (4, 24, 48, 4, 2, nothing), - (4, 48, 64, 4, 2, nothing), - (4, 64, 128, 6, 2, 4), - (6, 128, 160, 9, 1, 4), - (6, 160, 256, 15, 2, 4)], - :medium => [ - (1, 24, 24, 3, 1, nothing), - (4, 24, 48, 5, 2, nothing), - (4, 48, 80, 5, 2, nothing), - (4, 80, 160, 7, 2, 4), - (6, 160, 176, 14, 1, 4), - (6, 176, 304, 18, 2, 4), - (6, 304, 512, 5, 1, 4)], - :large => [ - (1, 32, 32, 4, 1, nothing), - (4, 32, 64, 7, 2, nothing), - (4, 64, 96, 7, 2, nothing), - (4, 96, 192, 10, 2, 4), - (6, 192, 224, 19, 1, 4), - (6, 224, 384, 25, 2, 4), - (6, 384, 640, 7, 1, 4)], - :xlarge => [ - (1, 32, 32, 4, 1, nothing), - (4, 32, 64, 8, 2, nothing), - (4, 64, 96, 8, 2, nothing), - (4, 96, 192, 16, 2, 4), - (6, 192, 256, 24, 1, 4), - (6, 256, 512, 32, 2, 4), - (6, 512, 640, 8, 1, 4)]) - -""" - EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - -# Arguments - - - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `width_mult`: Controls the number of output feature maps in each block (with 1 - being the default in the paper) - - `inchannels`: number of input channels - - `nclasses`: number of output classes -""" -struct EfficientNetv2 - layers::Any -end -@functor EfficientNetv2 - -function EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnetv2(EFFNETV2_CONFIGS[config]; width_mult, inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("efficientnetv2")) - end - return EfficientNetv2(layers) -end - -(m::EfficientNetv2)(x) = m.layers(x) - -backbone(m::EfficientNetv2) = m.layers[1] -classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl new file mode 100644 index 000000000..1059cb538 --- /dev/null +++ b/src/convnets/efficientnets/core.jl @@ -0,0 +1,78 @@ +abstract type _MBConfig end + +struct MBConvConfig <: _MBConfig + kernel_size::Dims{2} + inplanes::Integer + outplanes::Integer + expansion::Number + stride::Integer + nrepeats::Integer +end +function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, + expansion::Number, stride::Integer, nrepeats::Integer, + width_mult::Number = 1, depth_mult::Number = 1) + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + nrepeats = ceil(Int, nrepeats * depth_mult) + return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, + stride, nrepeats) +end + +function efficientnetblock(m::MBConvConfig, norm_layer) + layers = [] + explanes = _round_channels(m.inplanes * m.expansion, 8) + push!(layers, + mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer, + stride = m.stride, reduction = 4)) + explanes = _round_channels(m.outplanes * m.expansion, 8) + append!(layers, + [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer, + stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)]) + return Chain(layers...) +end + +struct FusedMBConvConfig <: _MBConfig + kernel_size::Dims{2} + inplanes::Integer + outplanes::Integer + expansion::Number + stride::Integer + nrepeats::Integer +end +function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, + expansion::Number, stride::Integer, nrepeats::Integer) + return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, + stride, nrepeats) +end + +function efficientnetblock(m::FusedMBConvConfig, norm_layer) + layers = [] + explanes = _round_channels(m.inplanes * m.expansion, 8) + push!(layers, + fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; + norm_layer, stride = m.stride)) + explanes = _round_channels(m.outplanes * m.expansion, 8) + append!(layers, + [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; + norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)]) + return Chain(layers...) +end + +function efficientnet(block_configs::AbstractVector{<:_MBConfig}; + headplanes::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) + layers = [] + # stem of the model + append!(layers, + conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer, + stride = 2, pad = SamePad())) + # building inverted residual blocks + append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs]) + # building last layers + outplanes = block_configs[end].outplanes + headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes + append!(layers, + conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad())) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) +end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl new file mode 100644 index 000000000..0bb481dda --- /dev/null +++ b/src/convnets/efficientnets/efficientnet.jl @@ -0,0 +1,59 @@ +# block configs for EfficientNet +const EFFICIENTNET_BLOCK_CONFIGS = [ + # k, i, o, e, s, n + (3, 32, 16, 1, 1, 1), + (3, 16, 24, 6, 2, 2), + (5, 24, 40, 6, 2, 2), + (3, 40, 80, 6, 2, 3), + (5, 80, 112, 6, 1, 3), + (5, 112, 192, 6, 2, 4), + (3, 192, 320, 6, 1, 1), +] + +# Data is organised as (r, (w, d)) +# r: image resolution +# w: width scaling +# d: depth scaling +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) + +""" + EfficientNet(config::Symbol; pretrain::Bool = false) + +Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). +See also [`efficientnet`](#). + +# Arguments + + - `config`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet +""" +struct EfficientNet + layers::Any +end +@functor EfficientNet + +function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) + _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + cfg_fn = (args...) -> MBConvConfig(args..., EFFICIENTNET_GLOBAL_CONFIGS[config][2]...) + block_configs = [cfg_fn(args...) for args in EFFICIENTNET_BLOCK_CONFIGS] + layers = efficientnet(block_configs; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnet-", config)) + end + return EfficientNet(layers) +end + +(m::EfficientNet)(x) = m.layers(x) + +backbone(m::EfficientNet) = m.layers[1] +classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl new file mode 100644 index 000000000..d2d6a3222 --- /dev/null +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -0,0 +1,67 @@ +# block configs for EfficientNetv2 +const EFFNETV2_CONFIGS = Dict(:small => [ + FusedMBConvConfig(3, 24, 24, 1, 1, 2), + FusedMBConvConfig(3, 24, 48, 4, 2, 4), + FusedMBConvConfig(3, 48, 64, 4, 2, 4), + MBConvConfig(3, 64, 128, 4, 2, 6), + MBConvConfig(3, 128, 160, 6, 1, 9), + MBConvConfig(3, 160, 256, 6, 2, 15)], + :medium => [ + FusedMBConvConfig(3, 24, 24, 1, 1, 3), + FusedMBConvConfig(3, 24, 48, 4, 2, 5), + FusedMBConvConfig(3, 48, 80, 4, 2, 5), + MBConvConfig(3, 80, 160, 4, 2, 7), + MBConvConfig(3, 160, 176, 6, 1, 14), + MBConvConfig(3, 176, 304, 6, 2, 18), + MBConvConfig(3, 304, 512, 6, 1, 5)], + :large => [ + FusedMBConvConfig(3, 32, 32, 1, 1, 4), + FusedMBConvConfig(3, 32, 64, 4, 2, 7), + FusedMBConvConfig(3, 64, 96, 4, 2, 7), + MBConvConfig(3, 96, 192, 4, 2, 10), + MBConvConfig(3, 192, 224, 6, 1, 19), + MBConvConfig(3, 224, 384, 6, 2, 25), + MBConvConfig(3, 384, 640, 6, 1, 7)], + :xlarge => [ + FusedMBConvConfig(3, 32, 32, 1, 1, 4), + FusedMBConvConfig(3, 32, 64, 4, 2, 8), + FusedMBConvConfig(3, 64, 96, 4, 2, 8), + MBConvConfig(3, 96, 192, 4, 2, 16), + MBConvConfig(3, 192, 224, 6, 1, 24), + MBConvConfig(3, 384, 512, 6, 2, 32), + MBConvConfig(3, 512, 768, 6, 1, 8)]) + +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `width_mult`: Controls the number of output feature maps in each block (with 1 + being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" +struct EfficientNetv2 + layers::Any +end +@functor EfficientNetv2 + +function EfficientNetv2(config::Symbol; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) + layers = efficientnet(EFFNETV2_CONFIGS[config]; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnetv2")) + end + return EfficientNetv2(layers) +end + +(m::EfficientNetv2)(x) = m.layers(x) + +backbone(m::EfficientNetv2) = m.layers[1] +classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/inception/googlenet.jl b/src/convnets/inceptions/googlenet.jl similarity index 100% rename from src/convnets/inception/googlenet.jl rename to src/convnets/inceptions/googlenet.jl diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl similarity index 100% rename from src/convnets/inception/inceptionresnetv2.jl rename to src/convnets/inceptions/inceptionresnetv2.jl diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl similarity index 100% rename from src/convnets/inception/inceptionv3.jl rename to src/convnets/inceptions/inceptionv3.jl diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl similarity index 100% rename from src/convnets/inception/inceptionv4.jl rename to src/convnets/inceptions/inceptionv4.jl diff --git a/src/convnets/inception/xception.jl b/src/convnets/inceptions/xception.jl similarity index 92% rename from src/convnets/inception/xception.jl rename to src/convnets/inceptions/xception.jl index 14b5444d6..33222e7be 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -19,8 +19,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int stride::Integer = 1, start_with_relu::Bool = true, grow_at_start::Bool = true) if outchannels != inchannels || stride != 1 - skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride) else skip = [identity] end @@ -35,8 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - dwsep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, - use_norm = (false, false))) + dwsep_conv_bn((3, 3), inc, outc; pad = 1, use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -57,8 +55,8 @@ Creates an Xception model. - `nclasses`: the number of output classes. """ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_norm((3, 3), 32, 64; bias = false)..., + backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 64)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), xception_block(128, 256, 2; stride = 2), xception_block(256, 728, 2; stride = 2), diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl similarity index 95% rename from src/convnets/mobilenet/mobilenetv1.jl rename to src/convnets/mobilenets/mobilenetv1.jl index db9dedbdb..caa899a53 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -27,9 +27,8 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati for _ in 1:nrepeats layer = dw ? dwsep_conv_bn((3, 3), inchannels, outchannels, activation; - stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1, - bias = false) + stride, pad = 1) : + conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1) append!(layers, layer) inchannels = outchannels end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl similarity index 97% rename from src/convnets/mobilenet/mobilenetv2.jl rename to src/convnets/mobilenets/mobilenetv2.jl index 7bc87bcb9..232286309 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -31,7 +31,7 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = _round_channels(32 * width_mult, divisor) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes; bias = false, pad = 1, stride = 2)) + conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, divisor) @@ -44,7 +44,7 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)) + append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl similarity index 98% rename from src/convnets/mobilenet/mobilenetv3.jl rename to src/convnets/mobilenets/mobilenetv3.jl index 68fe2f03b..78c55e144 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -30,8 +30,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1, - bias = false)) + conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) explanes = 0 # building inverted residual blocks for (k, t, c, reduction, activation, stride) in configs @@ -46,7 +45,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, # building last layers headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : max_width - append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) + append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(explanes, headplanes, hardswish), Dropout(dropout_rate), diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 1e6bb9fee..35bb34fc4 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -29,9 +29,9 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, first_planes = planes ÷ reduction_factor outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, - stride, pad = 1, bias = false) + stride, pad = 1) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, - pad = 1, bias = false) + pad = 1) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -72,12 +72,10 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, - bias = false) + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, - stride, pad = 1, groups = cardinality, bias = false) - conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm, - bias = false) + stride, pad = 1, groups = cardinality) + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -87,7 +85,7 @@ end function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false) return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - pad = SamePad(), stride, bias = false)...) + pad = SamePad(), stride)...) end # Downsample layer using max pooling @@ -95,8 +93,7 @@ function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer norm_layer = BatchNorm, revnorm::Bool = false) pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - bias = false)...) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -178,9 +175,9 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; - norm_layer, revnorm, stride = 2, pad = 1, bias = false)..., + norm_layer, revnorm, stride = 2, pad = 1)..., conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; - norm_layer, pad = 1, bias = false)..., + norm_layer, pad = 1)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) @@ -189,8 +186,7 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - revnorm, - stride = 2, pad = 1, bias = false)...) : + revnorm, stride = 2, pad = 1)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool) end diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index b5dae6663..e308e1125 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -30,17 +30,17 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, outplanes = planes * 4 pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride, - pad = 1, groups = cardinality, bias = false)...) + pad = 1, groups = cardinality)...) for _ in 1:max(1, scale - 1)] reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) layers = [ conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm, bias = false)..., + norm_layer, revnorm)..., chunk$(; size = width, dims = 3), tuplify, reslayer, conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm, bias = false)..., + norm_layer, revnorm)..., attn_fn(outplanes), ] return Chain(filter(!=(identity), layers)...) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index de232d9a3..163c13b68 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -17,7 +17,7 @@ function vgg_block(ifilters::Integer, ofilters::Integer, depth::Integer, batchno layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_norm(k, ifilters, ofilters; pad = p, bias = false)) + append!(layers, conv_norm(k, ifilters, ofilters; pad = p)) else push!(layers, Conv(k, ifilters => ofilters, relu; pad = p)) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 087082c8b..78300b0d0 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,10 +1,11 @@ """ - conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, - use_norm::Bool = true, stride::Integer = 1, pad::Integer = 0, - dilation::Integer = 1, groups::Integer = 1, [bias, weight, init]) + conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + stride::Integer = 1, pad::Integer = 0, dilation::Integer = 1, + groups::Integer = 1, [bias, weight, init]) - conv_norm(kernel_size, inplanes => outplanes, activation = identity; + conv_norm(kernel_size::Dims{2}, inplanes => outplanes, activation = identity; kwargs...) Create a convolution + batch normalization pair with activation. @@ -25,11 +26,14 @@ Create a convolution + batch normalization pair with activation. - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `bias`: bias for the convolution kernel. This is set to `false` by default if + `use_norm = true`. + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, - preact::Bool = false, use_norm::Bool = true, kwargs...) +function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + bias = !use_norm, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -56,30 +60,30 @@ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activatio end end # layers - layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), + layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; - kwargs...) +function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, + activation = identity; kwargs...) inplanes, outplanes = ch return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end # conv + bn layer combination as used by the inception model family matching # the default values used in TensorFlow -function basic_conv_bn(kernel_size, inplanes, outplanes, activation = relu; kwargs...) +function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu; + kwargs...) return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, - eps = 1.0f-3, bias = false, kwargs...) + eps = 1.0f-3, kwargs...) end """ - dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), - pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) + dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: @@ -102,31 +106,32 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `revnorm`: set to `true` to place the batch norm before the convolution - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution + - `bias`: a tuple of two booleans to specify whether to use bias for the first and second + convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and + `use_norm[1] == true`. - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function dwsep_conv_bn(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, +function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, + outplanes::Integer, activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), kwargs...) + use_norm::NTuple{2, Bool} = (true, true), + bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, - revnorm, use_norm = use_norm[1], stride, + revnorm, use_norm = use_norm[1], stride, bias = bias[1], groups = inplanes, kwargs...), conv_norm((1, 1), inplanes, outplanes, activation; eps, - revnorm, use_norm = use_norm[2])) + revnorm, use_norm = use_norm[2], bias = bias[2])) end +# TODO add support for stochastic depth to mbconv and fused_mbconv """ - mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, + mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, reduction::Union{Nothing, Integer} = nothing) - mbconv(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion::Real, - reduction::Union{Nothing, Integer} = nothing) - Create a basic inverted residual block for MobileNet variants ([reference](https://arxiv.org/abs/1905.02244)). @@ -134,46 +139,43 @@ Create a basic inverted residual block for MobileNet variants - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer. Alternatively, - specify the keyword argument `expansion`, which calculates the number of feature - maps in the hidden layer from the number of input feature maps as: - `hidden_planes = inplanes * expansion` + - `explanes`: The number of feature maps in the hidden layer - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)). + in a squeeze and excite layer (see [`squeeze_excite`](#)) """ -function mbconv(kernel_size, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing, +function mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] # expand - if inplanes != hidden_planes + if inplanes != explanes append!(layers, - conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false, - norm_layer)) + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end + # depthwise + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) - append!(layers, - squeeze_excite(hidden_planes, inplanes ÷ reduction; activation, - gate_activation = hardσ)) + push!(layers, + squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, + gate_activation = hardσ)) end - # depthwise - append!(layers, - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; bias = false, - norm_layer, stride, pad = SamePad(), groups = hidden_planes)) # project - append!(layers, conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)) + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : Chain(layers...) end -function fused_mbconv(kernel_size, inplanes::Integer, explanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, norm_layer = BatchNorm) +function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] if explanes != inplanes diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 31c06c07a..b252584fe 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -116,6 +116,7 @@ function Base.show(io::IO, d::DropBlock) return print(io, ")") end +# TODO look into "row" mode for stochastic depth """ DropPath(p; [rng = rng_from_array(x)]) diff --git a/test/convnets.jl b/test/convnets.jl index e087ceb0e..34bbb5121 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -161,7 +161,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8] + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] x = rand(Float32, r, r, 3, 1) @@ -178,7 +178,7 @@ end end @testset "EfficientNetv2" begin - @testset for config in [:small, :medium, :large, :xlarge] + @testset for config in [:small, :medium, :large] # :xlarge] m = EfficientNetv2(config) @test size(m(x_224)) == (1000, 1) if (EfficientNetv2, config) in PRETRAINED_MODELS From 9bc75fc6d5a9a8776b75d1f21c0626fd94115e83 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 13 Aug 2022 07:56:59 +0530 Subject: [PATCH 23/34] Fixes --- src/convnets/efficientnets/efficientnetv2.jl | 3 ++- src/layers/conv.jl | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index d2d6a3222..6847e7fa8 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -54,7 +54,8 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnet(EFFNETV2_CONFIGS[config]; inchannels, nclasses) + layers = efficientnet(EFFNETV2_CONFIGS[config]; headplanes = 1280, inchannels, + nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 78300b0d0..c94ceb045 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -181,7 +181,8 @@ function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, if explanes != inplanes # fused expand append!(layers, - conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride)) + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) else From 3a37a7023db779c3ee346ae91e60e3e894e1239e Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 14 Aug 2022 19:24:52 +0530 Subject: [PATCH 24/34] Some refactors, some consistency, some features --- .github/workflows/CI.yml | 3 +- src/convnets/convmixer.jl | 15 ++-- src/convnets/densenet.jl | 15 ++-- src/convnets/efficientnets/core.jl | 10 +-- src/convnets/inceptions/inceptionresnetv2.jl | 6 +- src/convnets/inceptions/inceptionv4.jl | 6 +- src/convnets/inceptions/xception.jl | 4 +- src/convnets/mobilenets/mobilenetv1.jl | 9 +- src/convnets/mobilenets/mobilenetv2.jl | 9 +- src/convnets/mobilenets/mobilenetv3.jl | 37 +++++--- src/convnets/resnets/core.jl | 19 ++-- src/layers/Layers.jl | 5 +- src/layers/classifier.jl | 93 ++++++++++++++++++++ src/layers/conv.jl | 9 +- src/layers/drop.jl | 44 +++++---- src/layers/mlp.jl | 44 +-------- src/utilities.jl | 8 +- 17 files changed, 215 insertions(+), 121 deletions(-) create mode 100644 src/layers/classifier.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 37cda3263..5304bc317 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,8 +28,7 @@ jobs: suite: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - - '"EfficientNet"' - - '"EfficientNetv2"' + - '"EfficientNet"' - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' - '[r"Res2Net", r"Res2NeXt"]' diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 1ca8487a9..bc1a71a5f 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -17,16 +17,21 @@ Creates a ConvMixer model. - `nclasses`: number of classes in the output """ function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), - patch_size::Dims{2} = (7, 7), activation = gelu, + patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) - stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, - stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; + layers = [] + # stem of the model + append!(layers, + conv_norm(patch_size, inchannels, planes, activation; preact = true, + stride = patch_size[1])) + # stages of the model + stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; preact = true, groups = planes, pad = SamePad())), +), conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] - return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses)) + append!(layers, stages) + return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate)) end const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index ca81b78ea..a7c367c1c 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates) end """ - densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000) + densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a DenseNet model ([reference](https://arxiv.org/abs/1608.06993)). @@ -66,10 +67,11 @@ Create a DenseNet model - `growth_rates`: the growth rates of output feature maps within each [`dense_block`](#) (a vector of vectors) - `reduction`: the factor by which the number of feature maps is scaled across each transition + - `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes """ -function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3, - nclasses::Integer = 1000) +function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] append!(layers, conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) @@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: inplanes = floor(Int, outplanes * reduction) end push!(layers, BatchNorm(outplanes, relu)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end """ @@ -100,9 +102,10 @@ Create a DenseNet model - `nclasses`: the number of output classes """ function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, - reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) + reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; - reduction, inchannels, nclasses) + reduction, dropout_rate, inchannels, nclasses) end const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 1059cb538..7a221c0e4 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -4,13 +4,13 @@ struct MBConvConfig <: _MBConfig kernel_size::Dims{2} inplanes::Integer outplanes::Integer - expansion::Number + expansion::Real stride::Integer nrepeats::Integer end function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Number, stride::Integer, nrepeats::Integer, - width_mult::Number = 1, depth_mult::Number = 1) + expansion::Real, stride::Integer, nrepeats::Integer, + width_mult::Real = 1, depth_mult::Real = 1) inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) nrepeats = ceil(Int, nrepeats * depth_mult) @@ -35,12 +35,12 @@ struct FusedMBConvConfig <: _MBConfig kernel_size::Dims{2} inplanes::Integer outplanes::Integer - expansion::Number + expansion::Real stride::Integer nrepeats::Integer end function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Number, stride::Integer, nrepeats::Integer) + expansion::Real, stride::Integer, nrepeats::Integer) return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, stride, nrepeats) end diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index 7f462c0cf..bd88648e9 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -64,7 +64,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -72,10 +72,10 @@ Creates an InceptionResNetv2 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, +function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index b43f6bc1d..13d40da25 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -85,7 +85,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -93,10 +93,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, +function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 33222e7be..171bddd19 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -43,14 +43,14 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end """ - xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) + xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - `nclasses`: the number of output classes. """ diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index caa899a53..542edec81 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -1,5 +1,6 @@ """ - mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, + mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). @@ -16,10 +17,12 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + `s`: The stride of the convolutional kernel + `r`: The number of time this configuration block is repeated - `activate`: The activation function to use throughout the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. - `inchannels`: The number of input channels. The default value is 3. - `nclasses`: The number of output classes """ -function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, +function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] for (dw, outchannels, stride, nrepeats) in config @@ -33,7 +36,7 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati inchannels = outchannels end end - return Chain(Chain(layers...), create_classifier(inchannels, nclasses)) + return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 232286309..d81256968 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -20,7 +20,7 @@ Create a MobileNetv2 model. (with 1 being the default in the paper) - `max_width`: The maximum number of feature maps in any layer of the network - `divisor`: The divisor used to round the number of feature maps in each block - - `dropout_rate`: rate of dropout in the classifier head + - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ @@ -33,12 +33,13 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - for (t, c, n, s, a) in configs + for (t, c, n, s, activation) in configs outplanes = _round_channels(c * width_mult, divisor) for i in 1:n + stride = i == 1 ? s : 1 push!(layers, - mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a; - stride = i == 1 ? s : 1)) + mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, + activation; stride)) inplanes = outplanes end end diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 78c55e144..82265e125 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -1,7 +1,7 @@ """ mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv3 model. ([reference](https://arxiv.org/abs/1905.02244)). @@ -19,12 +19,14 @@ Create a MobileNetv3 model. - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) - - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. + - `inchannels`: The number of input channels. - `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, dropout_rate = 0.2, + max_width::Integer = 1024, reduced_tail::Bool = false, + tail_dilated::Bool = false, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) @@ -32,25 +34,34 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) explanes = 0 + nstages = length(configs) + reduced_divider = 1 # building inverted residual blocks - for (k, t, c, reduction, activation, stride) in configs + for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs) + dilation = 1 + if nstages - i <= 2 + if reduced_tail + reduced_divider = 2 + c /= reduced_divider + end + if tail_dilated + dilation = 2 + end + end # inverted residual layers outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, mbconv((k, k), inplanes, explanes, outplanes, activation; - stride, reduction)) + stride, reduction, dilation)) inplanes = outplanes end # building last layers - headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) : - max_width + headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8) append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) - classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(explanes, headplanes, hardswish), - Dropout(dropout_rate), - Dense(headplanes, nclasses)) - return Chain(Chain(layers...), classifier) + return Chain(Chain(layers...), + create_classifier(explanes, headplanes, nclasses, + (hardswish, identity); dropout_rate)) end # Layer configurations for small and large models for MobileNetv3 diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 35bb34fc4..458481d73 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -27,12 +27,11 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor - outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, stride, pad = 1) - conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, + conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm, pad = 1) - layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), + layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes), drop_path] return Chain(filter!(!=(identity), layers)...) end @@ -201,7 +200,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; expansion::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) @@ -236,7 +235,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; expansion::Integer = 4, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) @@ -295,8 +294,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...) + use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, + dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) # Block builder @@ -319,8 +318,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0" - @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to 0.0" + @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" + @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, @@ -347,7 +346,7 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) - +# larger ResNet-like models const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 72ace2c2c..45615df5e 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -28,7 +28,10 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mlp.jl") -export mlp_block, gated_mlp_block, create_fc, create_classifier +export mlp_block, gated_mlp_block + +include("classifier.jl") +export create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl new file mode 100644 index 000000000..bebdc4099 --- /dev/null +++ b/src/layers/classifier.jl @@ -0,0 +1,93 @@ +""" + create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + +Creates a classifier head to be used for models. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `activation`: activation function to use + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv && pool_layer !== identity + if use_conv + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # Dropout is applied after the pooling layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # Fully-connected layer + if use_conv + push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) + else + push!(classifier, Dense(inplanes => nclasses, activation)) + end + return Chain(classifier...) +end + +""" + create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + +Creates a classifier head to be used for models with an extra hidden layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `hidden_planes`: number of hidden feature maps + - `nclasses`: number of output classes + - `activations`: activation functions to use for the hidden and output layers. This is a + tuple of two elements, the first being the activation function for the hidden layer and the + second for the output layer. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. This + is a tuple of two booleans, the first for the hidden layer and the second for the output + layer. + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2, Any} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + fc_layers = [uc ? Conv$(1, 1) : Dense for uc in use_conv] + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv[1] && pool_layer !== identity + if use_conv[1] + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv[1]` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # first fully-connected layer + if !isnothing(hidden_planes) + push!(classifier, fc_layers[1](inplanes => hidden_planes, activations[1])) + end + # Dropout is applied after the first dense layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # second fully-connected layer + push!(classifier, fc_layers[2](hidden_planes => nclasses, activations[2])) + return Chain(classifier...) +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c94ceb045..bb39a0e07 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -146,9 +146,9 @@ Create a basic inverted residual block for MobileNet variants - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)) """ -function mbconv(kernel_size::Dims{2}, inplanes::Integer, - explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, reduction::Union{Nothing, Integer} = nothing, +function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] @@ -158,9 +158,10 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end # depthwise + stride = dilation > 1 ? 1 : stride append!(layers, conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, pad = SamePad(), groups = explanes)) + stride, dilation, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) push!(layers, diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b252584fe..387b562ef 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -20,7 +20,8 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. - `x`: input array - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -56,11 +57,25 @@ dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) The `DropBlock` layer. While training, it zeroes out continguous regions of size `block_size` in the input. During inference, it simply returns the input `x`. +It can be used in two ways: either with all blocks having the same survival probability +or with a linear scaling rule across the blocks. This is performed only at training time. +At test time, the `DropBlock` layer is equivalent to `identity`. + +!!! warning + + In the case of the linear scaling rule, the calculations of survival probabilities for each + block may lead to a survival probability > 1 for a given block. This will lead to + `DropBlock` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. + ((reference)[https://arxiv.org/abs/1810.12890]) # Arguments - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -90,11 +105,8 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_s function (m::DropBlock)(x) _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) - if Flux._isactive(m) - return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) - else - return x - end + return Flux._isactive(m) ? + dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) : x end function Flux.testmode!(m::DropBlock, mode = true) @@ -103,7 +115,7 @@ end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, rng = rng_from_array()) - if drop_block_prob == 0.0 + if isnothing(drop_block_prob) return identity end return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) @@ -120,8 +132,8 @@ end """ DropPath(p; [rng = rng_from_array(x)]) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and -`identity` otherwise. +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 ≤ p ≤ 1` and +`identity` if p is `nothing`. ([reference](https://arxiv.org/abs/1603.09382)) This layer can be used to drop certain blocks in a residual structure and allow them to @@ -134,10 +146,10 @@ equivalent to `identity`. In the case of the linear scaling rule, the calculations of survival probabilities for each block may lead to a survival probability > 1 for a given block. This will lead to - `DropPath` returning `identity`, which may not be desirable. This usually happens with - a low number of blocks and a high base survival probability, so it is recommended to - use a fixed base survival probability across blocks. If this is not possible, then - a lower base survival probability is recommended. + `DropPath` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. # Arguments @@ -146,4 +158,6 @@ equivalent to `identity`. for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU. """ -DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity +function DropPath(p; rng = rng_from_array()) + return isnothing(p) ? identity : Dropout(p; dims = 4, rng) +end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 467df30a4..e6336de9c 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,3 +1,4 @@ +# TODO @theabhirath figure out consistent behaviour for dropout rates - 0.0 vs `nothing` """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; dropout_rate = 0., activation = gelu) @@ -45,46 +46,3 @@ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) - -""" - create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = 0.0, use_conv::Bool = false) - -Creates a classifier head to be used for models. - -# Arguments - - - `inplanes`: number of input feature maps - - `nclasses`: number of output classes - - `activation`: activation function to use - - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with - any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. - - `dropout_rate`: dropout rate used in the classifier head. - - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. -""" -function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = nothing) - # Decide whether to flatten the input or not - flatten_in_pool = !use_conv && pool_layer !== identity - if use_conv - @assert pool_layer === identity - "`pool_layer` must be identity if `use_conv` is true" - end - classifier = [] - if flatten_in_pool - push!(classifier, pool_layer, MLUtils.flatten) - else - push!(classifier, pool_layer) - end - # Dropout is applied after the pooling layer - isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) - # Fully-connected layer - if use_conv - push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) - else - push!(classifier, Dense(inplanes => nclasses, activation)) - end - return Chain(classifier...) -end diff --git a/src/utilities.jl b/src/utilities.jl index 4a611b5a2..09074b0e8 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -66,13 +66,17 @@ function _maybe_big_show(io, model) end """ - linear_scheduler(drop_path_rate = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_rate = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_rate::Nothing; depth::Integer) -Returns the dropout rates for a given depth using the linear scaling rule. +Returns the dropout rates for a given depth using the linear scaling rule. If the +`drop_rate` is `nothing`, it returns a `Vector` of length `depth` with all values +equal to `nothing`. """ function linear_scheduler(drop_rate = 0.0; depth::Integer, start_value = 0.0) return LinRange(start_value, drop_rate, depth) end +linear_scheduler(drop_rate::Nothing; depth::Integer) = fill(drop_rate, depth) # Utility function for depth and configuration checks in models function _checkconfig(config, configs) From 593752f943e453692b49bdb2ddd4ab9b0cc1c44a Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Mon, 15 Aug 2022 17:11:33 +0530 Subject: [PATCH 25/34] The real hero was `block_idx` all along --- src/convnets/efficientnets/core.jl | 101 ++++++++----------- src/convnets/efficientnets/efficientnet.jl | 8 +- src/convnets/efficientnets/efficientnetv2.jl | 60 +++++------ src/convnets/resnets/core.jl | 56 +++++----- src/convnets/resnets/res2net.jl | 1 - src/layers/Layers.jl | 5 +- src/layers/conv.jl | 68 ------------- src/layers/mbconv.jl | 65 ++++++++++++ 8 files changed, 176 insertions(+), 188 deletions(-) create mode 100644 src/layers/mbconv.jl diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 7a221c0e4..139252a3b 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,78 +1,61 @@ -abstract type _MBConfig end - -struct MBConvConfig <: _MBConfig - kernel_size::Dims{2} - inplanes::Integer - outplanes::Integer - expansion::Real - stride::Integer - nrepeats::Integer -end -function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Real, stride::Integer, nrepeats::Integer, - width_mult::Real = 1, depth_mult::Real = 1) +function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, + stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) + depth_mult, width_mult = scalings + k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) - nrepeats = ceil(Int, nrepeats * depth_mult) - return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, - stride, nrepeats) + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = mbconv((k, k), inplanes, explanes, outplanes, swish; norm_layer, + stride, reduction = 4) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, ceil(Int, nrepeats * depth_mult) end -function efficientnetblock(m::MBConvConfig, norm_layer) - layers = [] - explanes = _round_channels(m.inplanes * m.expansion, 8) - push!(layers, - mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer, - stride = m.stride, reduction = 4)) - explanes = _round_channels(m.outplanes * m.expansion, 8) - append!(layers, - [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer, - stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)]) - return Chain(layers...) +function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, + stage_idx::Integer; norm_layer = BatchNorm) + k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = fused_mbconv((k, k), inplanes, explanes, outplanes, swish; + norm_layer, stride) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats end -struct FusedMBConvConfig <: _MBConfig - kernel_size::Dims{2} - inplanes::Integer - outplanes::Integer - expansion::Real - stride::Integer - nrepeats::Integer -end -function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Real, stride::Integer, nrepeats::Integer) - return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, - stride, nrepeats) -end - -function efficientnetblock(m::FusedMBConvConfig, norm_layer) - layers = [] - explanes = _round_channels(m.inplanes * m.expansion, 8) - push!(layers, - fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; - norm_layer, stride = m.stride)) - explanes = _round_channels(m.outplanes * m.expansion, 8) - append!(layers, - [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; - norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)]) - return Chain(layers...) +function efficientnet_builder(block_configs::AbstractVector{NTuple{6, Int}}, + residual_fns::AbstractVector; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) + bxs = [residual_fn(block_configs, stage_idx; scalings, norm_layer) + for (stage_idx, residual_fn) in enumerate(residual_fns)] + return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end -function efficientnet(block_configs::AbstractVector{<:_MBConfig}; - headplanes::Union{Nothing, Integer} = nothing, +function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, + residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), + headplanes::Integer = _round_channels(block_configs[end][3] * + scalings[2], 8) * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer, + conv_norm((3, 3), inchannels, block_configs[1][2], swish; norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks - append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs]) + get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; + scalings, norm_layer) + append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers - outplanes = block_configs[end].outplanes - headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes append!(layers, - conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad())) + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[2], 8), + headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 0bb481dda..bff9d8dde 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -9,7 +9,6 @@ const EFFICIENTNET_BLOCK_CONFIGS = [ (5, 112, 192, 6, 2, 4), (3, 192, 320, 6, 1, 1), ] - # Data is organised as (r, (w, d)) # r: image resolution # w: width scaling @@ -44,9 +43,10 @@ end function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - cfg_fn = (args...) -> MBConvConfig(args..., EFFICIENTNET_GLOBAL_CONFIGS[config][2]...) - block_configs = [cfg_fn(args...) for args in EFFICIENTNET_BLOCK_CONFIGS] - layers = efficientnet(block_configs; inchannels, nclasses) + scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] + layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS, + fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS)); + scalings, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index 6847e7fa8..ecbeed07a 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,35 +1,36 @@ # block configs for EfficientNetv2 +# data organised as (k, i, o, e, s, n) const EFFNETV2_CONFIGS = Dict(:small => [ - FusedMBConvConfig(3, 24, 24, 1, 1, 2), - FusedMBConvConfig(3, 24, 48, 4, 2, 4), - FusedMBConvConfig(3, 48, 64, 4, 2, 4), - MBConvConfig(3, 64, 128, 4, 2, 6), - MBConvConfig(3, 128, 160, 6, 1, 9), - MBConvConfig(3, 160, 256, 6, 2, 15)], + (3, 24, 24, 1, 1, 2), + (3, 24, 48, 4, 2, 4), + (3, 48, 64, 4, 2, 4), + (3, 64, 128, 4, 2, 6), + (3, 128, 160, 6, 1, 9), + (3, 160, 256, 6, 2, 15)], :medium => [ - FusedMBConvConfig(3, 24, 24, 1, 1, 3), - FusedMBConvConfig(3, 24, 48, 4, 2, 5), - FusedMBConvConfig(3, 48, 80, 4, 2, 5), - MBConvConfig(3, 80, 160, 4, 2, 7), - MBConvConfig(3, 160, 176, 6, 1, 14), - MBConvConfig(3, 176, 304, 6, 2, 18), - MBConvConfig(3, 304, 512, 6, 1, 5)], + (3, 24, 24, 1, 1, 3), + (3, 24, 48, 4, 2, 5), + (3, 48, 80, 4, 2, 5), + (3, 80, 160, 4, 2, 7), + (3, 160, 176, 6, 1, 14), + (3, 176, 304, 6, 2, 18), + (3, 304, 512, 6, 1, 5)], :large => [ - FusedMBConvConfig(3, 32, 32, 1, 1, 4), - FusedMBConvConfig(3, 32, 64, 4, 2, 7), - FusedMBConvConfig(3, 64, 96, 4, 2, 7), - MBConvConfig(3, 96, 192, 4, 2, 10), - MBConvConfig(3, 192, 224, 6, 1, 19), - MBConvConfig(3, 224, 384, 6, 2, 25), - MBConvConfig(3, 384, 640, 6, 1, 7)], + (3, 32, 32, 1, 1, 4), + (3, 32, 64, 4, 2, 7), + (3, 64, 96, 4, 2, 7), + (3, 96, 192, 4, 2, 10), + (3, 192, 224, 6, 1, 19), + (3, 224, 384, 6, 2, 25), + (3, 384, 640, 6, 1, 7)], :xlarge => [ - FusedMBConvConfig(3, 32, 32, 1, 1, 4), - FusedMBConvConfig(3, 32, 64, 4, 2, 8), - FusedMBConvConfig(3, 64, 96, 4, 2, 8), - MBConvConfig(3, 96, 192, 4, 2, 16), - MBConvConfig(3, 192, 224, 6, 1, 24), - MBConvConfig(3, 384, 512, 6, 2, 32), - MBConvConfig(3, 512, 768, 6, 1, 8)]) + (3, 32, 32, 1, 1, 4), + (3, 32, 64, 4, 2, 8), + (3, 64, 96, 4, 2, 8), + (3, 96, 192, 4, 2, 16), + (3, 192, 224, 6, 1, 24), + (3, 384, 512, 6, 2, 32), + (3, 512, 768, 6, 1, 8)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, @@ -54,8 +55,9 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnet(EFFNETV2_CONFIGS[config]; headplanes = 1280, inchannels, - nclasses) + layers = efficientnet(EFFNETV2_CONFIGS[config], + vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4)); + headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 458481d73..39e283fd0 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -203,6 +203,8 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # Also get `planes_vec` needed for block `inplanes` and `planes` calculations pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) planes_vec = collect(planes_fn(block_repeats)) @@ -265,22 +267,26 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; return get_layers end +# TODO @theabhirath figure out a better name and potentially refactor other CNNs to use this function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) # Construct each stage stages = [] - for (stage_idx, num_blocks) in enumerate(block_repeats) + for (stage_idx, nblocks) in enumerate(block_repeats) # Construct the blocks for each stage - blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...) - for block_idx in 1:num_blocks] + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + return (length(branches) == 1) ? only(branches) : + Parallel(connection, branches...) + end push!(stages, Chain(blocks...)) end return Chain(stages...) end -function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, +function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection) + stage_blocks = resnet_stages(builders, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -302,39 +308,37 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, if block_type == basicblock @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" - get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, - activation, norm_layer, revnorm, attn_fn, - drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, - planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + builder = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, revnorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck - get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, - reduction_factor, activation, norm_layer, revnorm, - attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, - planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + builder = bottleneck_builder(block_repeats; inplanes, cardinality, + base_width, reduction_factor, activation, norm_layer, + revnorm, attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" - @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" - @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" + @assert isnothing(drop_block_rate) + "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" + @assert isnothing(drop_path_rate) + "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" + @assert reduction_factor == 1 + "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + downsample_tuple = downsample_opt, kwargs...) else # TODO: write better message when we have link to dev docs for resnet throw(ArgumentError("Unknown block type $block_type")) end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, get_layers, block_repeats, - connection$activation, classifier_fn) + return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)), + block_repeats, connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index e308e1125..08f2c87ee 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -54,7 +54,6 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; attn_fn = planes -> identity, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) - planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) # This is needed for block `inplanes` and `planes` calculations diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 45615df5e..9bdf1f913 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -19,7 +19,7 @@ include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, basic_conv_bn, dwsep_conv_bn, mbconv, fused_mbconv +export conv_norm, basic_conv_bn, dwsep_conv_bn include("drop.jl") export DropBlock, DropPath @@ -27,6 +27,9 @@ export DropBlock, DropPath include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens +include("mbconv.jl") +export mbconv, fused_mbconv + include("mlp.jl") export mlp_block, gated_mlp_block diff --git a/src/layers/conv.jl b/src/layers/conv.jl index bb39a0e07..d81d76c9c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -125,71 +125,3 @@ function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, conv_norm((1, 1), inplanes, outplanes, activation; eps, revnorm, use_norm = use_norm[2], bias = bias[2])) end - -# TODO add support for stochastic depth to mbconv and fused_mbconv -""" - mbconv(kernel_size, inplanes::Integer, explanes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing) - -Create a basic inverted residual block for MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `kernel_size`: kernel size of the convolutional layers - - `inplanes`: number of input feature maps - - `explanes`: The number of feature maps in the hidden layer - - `outplanes`: The number of output feature maps - - `activation`: The activation function for the first two convolution layer - - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)) -""" -function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, - norm_layer = BatchNorm) - @assert stride in [1, 2] "`stride` has to be 1 or 2" - layers = [] - # expand - if inplanes != explanes - append!(layers, - conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) - end - # depthwise - stride = dilation > 1 ? 1 : stride - append!(layers, - conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, dilation, pad = SamePad(), groups = explanes)) - # squeeze-excite layer - if !isnothing(reduction) - push!(layers, - squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, - gate_activation = hardσ)) - end - # project - append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) - return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : - Chain(layers...) -end - -function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, - explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, norm_layer = BatchNorm) - @assert stride in [1, 2] "`stride` has to be 1 or 2" - layers = [] - if explanes != inplanes - # fused expand - append!(layers, - conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, - pad = SamePad())) - # project - append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) - else - append!(layers, - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) - end - return stride == 1 && inplanes == outplanes ? SkipConnection(Chain(layers...), +) : - Chain(layers...) -end diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl new file mode 100644 index 000000000..af17505ea --- /dev/null +++ b/src/layers/mbconv.jl @@ -0,0 +1,65 @@ +# TODO add support for stochastic depth to mbconv and fused_mbconv +""" + mbconv(kernel_size, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Integer} = nothing) + +Create a basic inverted residual block for MobileNet variants +([reference](https://arxiv.org/abs/1905.02244)). + +# Arguments + + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps + - `explanes`: The number of feature maps in the hidden layer + - `outplanes`: The number of output feature maps + - `activation`: The activation function for the first two convolution layer + - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 + - `reduction`: The reduction factor for the number of hidden feature maps + in a squeeze and excite layer (see [`squeeze_excite`](#)) +""" +function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + layers = [] + # expand + if inplanes != explanes + append!(layers, + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) + end + # depthwise + stride = dilation > 1 ? 1 : stride + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, dilation, pad = SamePad(), groups = explanes)) + # squeeze-excite layer + if !isnothing(reduction) + push!(layers, + squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, + gate_activation = hardσ)) + end + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) + return Chain(layers...) +end + +function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `fused_mbconv`" + layers = [] + if explanes != inplanes + # fused expand + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) + else + append!(layers, + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) + end + return Chain(layers...) +end From 55bc54466da7638b7476ba3ee4ab23acb7146ba8 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Mon, 15 Aug 2022 22:51:42 +0530 Subject: [PATCH 26/34] Fix minor hiccups --- src/convnets/efficientnets/core.jl | 15 ++++----- src/convnets/efficientnets/efficientnetv2.jl | 3 +- src/convnets/resnets/core.jl | 32 +++++++++++--------- src/convnets/resnets/res2net.jl | 1 + 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 139252a3b..22ddf8172 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,7 +1,7 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - depth_mult, width_mult = scalings + width_mult, depth_mult = scalings k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) @@ -17,7 +17,8 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, end function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; norm_layer = BatchNorm) + stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes @@ -40,22 +41,22 @@ end function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), - headplanes::Integer = _round_channels(block_configs[end][3] * - scalings[2], 8) * 4, + headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, block_configs[1][2], swish; norm_layer, - stride = 2, pad = SamePad())) + conv_norm((3, 3), inchannels, + _round_channels(block_configs[1][2] * scalings[1], 8), swish; + norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; scalings, norm_layer) append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[2], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index ecbeed07a..d9a9d0d77 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -56,7 +56,8 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) layers = efficientnet(EFFNETV2_CONFIGS[config], - vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4)); + vcat(fill(fused_mbconv_builder, 3), + fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 39e283fd0..95291bc6f 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -275,7 +275,7 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con # Construct the blocks for each stage blocks = map(1:nblocks) do block_idx branches = get_layers(stage_idx, block_idx) - return (length(branches) == 1) ? only(branches) : + return length(branches) == 1 ? only(branches) : Parallel(connection, branches...) end push!(stages, Chain(blocks...)) @@ -283,10 +283,10 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con return Chain(stages...) end -function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer}, +function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(builders, block_repeats, connection) + stage_blocks = resnet_stages(get_layers, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -308,17 +308,19 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, if block_type == basicblock @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" - builder = basicblock_builder(block_repeats; inplanes, reduction_factor, - activation, norm_layer, revnorm, attn_fn, - drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, kwargs...) + get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, revnorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck - builder = bottleneck_builder(block_repeats; inplanes, cardinality, - base_width, reduction_factor, activation, norm_layer, - revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, kwargs...) + get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" @@ -337,8 +339,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)), - block_repeats, connection$activation, classifier_fn) + return resnet((imsize..., inchannels), stem, get_layers, block_repeats, + connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 08f2c87ee..e308e1125 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -54,6 +54,7 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; attn_fn = planes -> identity, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) + planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) # This is needed for block `inplanes` and `planes` calculations From 3d63f7263442a5ff8a93cdb71260297f7ce1b1c6 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 16 Aug 2022 12:01:37 +0530 Subject: [PATCH 27/34] Moving closer to the one true function --- src/convnets/efficientnets/core.jl | 52 +++++++++-------- src/convnets/efficientnets/efficientnet.jl | 18 +++--- src/convnets/efficientnets/efficientnetv2.jl | 61 ++++++++++---------- src/convnets/mobilenets/mobilenetv2.jl | 43 +++++++------- src/layers/mbconv.jl | 3 +- 5 files changed, 90 insertions(+), 87 deletions(-) diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 22ddf8172..ee94ff270 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,62 +1,66 @@ -function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) +function mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, + round_fn = planes -> _round_channels(planes, 8)) width_mult, depth_mult = scalings - k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] - inplanes = _round_channels(inplanes * width_mult, 8) + k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + inplanes = round_fn(inplanes * width_mult) outplanes = _round_channels(outplanes * width_mult, 8) function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = mbconv((k, k), inplanes, explanes, outplanes, swish; norm_layer, - stride, reduction = 4) + block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, ceil(Int, nrepeats * depth_mult) end -function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) - k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] +function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) + k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = fused_mbconv((k, k), inplanes, explanes, outplanes, swish; + block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, stride) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, nrepeats end -function efficientnet_builder(block_configs::AbstractVector{NTuple{6, Int}}, - residual_fns::AbstractVector; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - bxs = [residual_fn(block_configs, stage_idx; scalings, norm_layer) +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, + residual_fns::AbstractVector; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) + bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) for (stage_idx, residual_fn) in enumerate(residual_fns)] return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end -function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, - residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), +function efficientnet(block_configs::AbstractVector{<:Tuple}, + residual_fns::AbstractVector; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, - _round_channels(block_configs[1][2] * scalings[1], 8), swish; - norm_layer, stride = 2, pad = SamePad())) + conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8), + swish; norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks - get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; - scalings, norm_layer) + get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns; + inplanes, scalings, norm_layer) append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), + conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index bff9d8dde..495b8658f 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -1,13 +1,13 @@ # block configs for EfficientNet const EFFICIENTNET_BLOCK_CONFIGS = [ - # k, i, o, e, s, n - (3, 32, 16, 1, 1, 1), - (3, 16, 24, 6, 2, 2), - (5, 24, 40, 6, 2, 2), - (3, 40, 80, 6, 2, 3), - (5, 80, 112, 6, 1, 3), - (5, 112, 192, 6, 2, 4), - (3, 192, 320, 6, 1, 1), + # k, c, e, s, n, r, a + (3, 16, 1, 1, 1, 4, swish), + (3, 24, 6, 2, 2, 4, swish), + (5, 40, 6, 2, 2, 4, swish), + (3, 80, 6, 2, 3, 4, swish), + (5, 112, 6, 1, 3, 4, swish), + (5, 192, 6, 2, 4, 4, swish), + (3, 320, 6, 1, 1, 4, swish), ] # Data is organised as (r, (w, d)) # r: image resolution @@ -46,7 +46,7 @@ function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Intege scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS, fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS)); - scalings, inchannels, nclasses) + inplanes = 32, scalings, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index d9a9d0d77..69f67cc81 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,36 +1,36 @@ # block configs for EfficientNetv2 -# data organised as (k, i, o, e, s, n) +# data organised as (k, c, e, s, n, r, a) const EFFNETV2_CONFIGS = Dict(:small => [ - (3, 24, 24, 1, 1, 2), - (3, 24, 48, 4, 2, 4), - (3, 48, 64, 4, 2, 4), - (3, 64, 128, 4, 2, 6), - (3, 128, 160, 6, 1, 9), - (3, 160, 256, 6, 2, 15)], + (3, 24, 1, 1, 2, nothing, swish), + (3, 48, 4, 2, 4, nothing, swish), + (3, 64, 4, 2, 4, nothing, swish), + (3, 128, 4, 2, 6, 4, swish), + (3, 160, 6, 1, 9, 4, swish), + (3, 256, 6, 2, 15, 4, swish)], :medium => [ - (3, 24, 24, 1, 1, 3), - (3, 24, 48, 4, 2, 5), - (3, 48, 80, 4, 2, 5), - (3, 80, 160, 4, 2, 7), - (3, 160, 176, 6, 1, 14), - (3, 176, 304, 6, 2, 18), - (3, 304, 512, 6, 1, 5)], + (3, 24, 1, 1, 3, nothing, swish), + (3, 48, 4, 2, 5, nothing, swish), + (3, 80, 4, 2, 5, nothing, swish), + (3, 160, 4, 2, 7, 4, swish), + (3, 176, 6, 1, 14, 4, swish), + (3, 304, 6, 2, 18, 4, swish), + (3, 512, 6, 1, 5, 4, swish)], :large => [ - (3, 32, 32, 1, 1, 4), - (3, 32, 64, 4, 2, 7), - (3, 64, 96, 4, 2, 7), - (3, 96, 192, 4, 2, 10), - (3, 192, 224, 6, 1, 19), - (3, 224, 384, 6, 2, 25), - (3, 384, 640, 6, 1, 7)], + (3, 32, 1, 1, 4, nothing, swish), + (3, 64, 4, 2, 7, nothing, swish), + (3, 96, 4, 2, 7, nothing, swish), + (3, 192, 4, 2, 10, 4, swish), + (3, 224, 6, 1, 19, 4, swish), + (3, 384, 6, 2, 25, 4, swish), + (3, 640, 6, 1, 7, 4, swish)], :xlarge => [ - (3, 32, 32, 1, 1, 4), - (3, 32, 64, 4, 2, 8), - (3, 64, 96, 4, 2, 8), - (3, 96, 192, 4, 2, 16), - (3, 192, 224, 6, 1, 24), - (3, 384, 512, 6, 2, 32), - (3, 512, 768, 6, 1, 8)]) + (3, 32, 1, 1, 4, nothing, swish), + (3, 64, 4, 2, 8, nothing, swish), + (3, 96, 4, 2, 8, nothing, swish), + (3, 192, 4, 2, 16, 4, swish), + (3, 384, 6, 1, 24, 4, swish), + (3, 512, 6, 2, 32, 4, swish), + (3, 768, 6, 1, 8, 4, swish)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, @@ -58,9 +58,10 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false, layers = efficientnet(EFFNETV2_CONFIGS[config], vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); - headplanes = 1280, inchannels, nclasses) + inplanes = EFFNETV2_CONFIGS[config][1][2], headplanes = 1280, + inchannels, nclasses) if pretrain - loadpretrain!(layers, string("efficientnetv2")) + loadpretrain!(layers, string("efficientnetv2-", config)) end return EfficientNetv2(layers) end diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index d81256968..16b59c3a8 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -14,7 +14,6 @@ Create a MobileNetv2 model. + `c`: The number of output feature maps + `n`: The number of times a block is repeated + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper) @@ -24,41 +23,39 @@ Create a MobileNetv2 model. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ -function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, +function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1280, divisor::Integer = 8, + inplanes::Integer = 32, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer - inplanes = _round_channels(32 * width_mult, divisor) + inplanes = _round_channels(inplanes * width_mult, divisor) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - for (t, c, n, s, activation) in configs - outplanes = _round_channels(c * width_mult, divisor) - for i in 1:n - stride = i == 1 ? s : 1 - push!(layers, - mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, - activation; stride)) - inplanes = outplanes - end - end + get_layers, block_repeats = mbconv_stack_builder(block_configs, + fill(mbconv_builder, + length(block_configs)); + inplanes) + append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6)) + append!(layers, + conv_norm((1, 1), _round_channels(block_configs[end][2], 8), + outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 const MOBILENETV2_CONFIGS = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), + # k, c, e, s, n, r, a + (3, 16, 1, 1, 1, nothing, relu6), + (3, 24, 6, 2, 2, nothing, relu6), + (3, 32, 6, 2, 3, nothing, relu6), + (3, 64, 6, 2, 4, nothing, relu6), + (3, 96, 6, 1, 3, nothing, relu6), + (3, 160, 6, 2, 3, nothing, relu6), + (3, 320, 6, 1, 1, nothing, relu6), ] """ diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index af17505ea..eeac44dbd 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -59,7 +59,8 @@ function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) else append!(layers, - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride)) + conv_norm(kernel_size, inplanes, outplanes, activation; pad = SamePad(), + norm_layer, stride)) end return Chain(layers...) end From 818c5847aed765f34fedbbbbced9770efb33ebd1 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 16 Aug 2022 15:06:29 +0530 Subject: [PATCH 28/34] Some more reorganisation --- src/Metalhead.jl | 5 ++ src/convnets/builders/core.jl | 15 +++++ src/convnets/builders/mbconv.jl | 44 ++++++++++++ src/convnets/builders/resblocks.jl | 71 ++++++++++++++++++++ src/convnets/efficientnets/core.jl | 47 +------------ src/convnets/mobilenets/mobilenetv2.jl | 2 +- src/convnets/resnets/core.jl | 93 +------------------------- 7 files changed, 140 insertions(+), 137 deletions(-) create mode 100644 src/convnets/builders/core.jl create mode 100644 src/convnets/builders/mbconv.jl create mode 100644 src/convnets/builders/resblocks.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 6b0179f45..8c8800e84 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -19,6 +19,11 @@ include("layers/Layers.jl") using .Layers # CNN models +## Builders +include("convnets/builders/core.jl") +include("convnets/builders/mbconv.jl") +include("convnets/builders/resblocks.jl") +## AlexNet and VGG include("convnets/alexnet.jl") include("convnets/vgg.jl") ## ResNets diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl new file mode 100644 index 000000000..413c78c27 --- /dev/null +++ b/src/convnets/builders/core.jl @@ -0,0 +1,15 @@ +# TODO potentially refactor other CNNs to use this +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) + # Construct each stage + stages = [] + for (stage_idx, nblocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + return length(branches) == 1 ? only(branches) : + Parallel(connection, branches...) + end + push!(stages, Chain(blocks...)) + end + return Chain(stages...) +end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl new file mode 100644 index 000000000..3312c1cfe --- /dev/null +++ b/src/convnets/builders/mbconv.jl @@ -0,0 +1,44 @@ +function mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, + round_fn = planes -> _round_channels(planes, 8)) + width_mult, depth_mult = scalings + k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + inplanes = round_fn(inplanes * width_mult) + outplanes = _round_channels(outplanes * width_mult, 8) + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, ceil(Int, nrepeats * depth_mult) +end + +function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) + k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + function get_layers(block_idx) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; + norm_layer, stride) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats +end + +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, + residual_fns::AbstractVector; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) + bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) + for (stage_idx, residual_fn) in enumerate(residual_fns)] + return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) +end diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl new file mode 100644 index 000000000..2bf03746f --- /dev/null +++ b/src/convnets/builders/resblocks.jl @@ -0,0 +1,71 @@ +function basicblock_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 1, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + drop_block_rate = nothing, drop_path_rate = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # Also get `planes_vec` needed for block `inplanes` and `planes` calculations + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = stride != 1 || inplanes != planes * expansion ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = basicblock(inplanes, planes; stride, reduction_factor, activation, + norm_layer, revnorm, attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end + +function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + drop_block_rate = nothing, drop_path_rate = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = stride != 1 || inplanes != planes * expansion ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = bottleneck(inplanes, planes; stride, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end \ No newline at end of file diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index ee94ff270..c37fb8c93 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,48 +1,3 @@ -function mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, - round_fn = planes -> _round_channels(planes, 8)) - width_mult, depth_mult = scalings - k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] - inplanes = round_fn(inplanes * width_mult) - outplanes = _round_channels(outplanes * width_mult, 8) - function get_layers(block_idx) - inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) - stride = block_idx == 1 ? stride : 1 - block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, - stride, reduction) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) - end - return get_layers, ceil(Int, nrepeats * depth_mult) -end - -function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] - function get_layers(block_idx) - inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) - stride = block_idx == 1 ? stride : 1 - block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; - norm_layer, stride) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) - end - return get_layers, nrepeats -end - -function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, - residual_fns::AbstractVector; inplanes::Integer, - scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) - bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) - for (stage_idx, residual_fn) in enumerate(residual_fns)] - return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) -end - function efficientnet(block_configs::AbstractVector{<:Tuple}, residual_fns::AbstractVector; inplanes::Integer, scalings::NTuple{2, Real} = (1, 1), @@ -57,7 +12,7 @@ function efficientnet(block_configs::AbstractVector{<:Tuple}, # building inverted residual blocks get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns; inplanes, scalings, norm_layer) - append!(layers, resnet_stages(get_layers, block_repeats, +)) + append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers append!(layers, conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8), diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 16b59c3a8..ad41ff967 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -37,7 +37,7 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = fill(mbconv_builder, length(block_configs)); inplanes) - append!(layers, resnet_stages(get_layers, block_repeats, +)) + append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 95291bc6f..83ab9da04 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -195,98 +195,10 @@ function resnet_planes(block_repeats::AbstractVector{<:Integer}) for (stage_idx, stages) in enumerate(block_repeats)) end -function basicblock_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, reduction_factor::Integer = 1, - expansion::Integer = 1, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = nothing, drop_path_rate = nothing, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # Also get `planes_vec` needed for block `inplanes` and `planes` calculations - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # This is also needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = stride != 1 || inplanes != planes * expansion ? - downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - block = basicblock(inplanes, planes; stride, reduction_factor, activation, - norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - -function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, cardinality::Integer = 1, - base_width::Integer = 64, reduction_factor::Integer = 1, - expansion::Integer = 4, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = nothing, drop_path_rate = nothing, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # This is also needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = stride != 1 || inplanes != planes * expansion ? - downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - block = bottleneck(inplanes, planes; stride, cardinality, base_width, - reduction_factor, activation, norm_layer, revnorm, - attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - -# TODO @theabhirath figure out a better name and potentially refactor other CNNs to use this -function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) - # Construct each stage - stages = [] - for (stage_idx, nblocks) in enumerate(block_repeats) - # Construct the blocks for each stage - blocks = map(1:nblocks) do block_idx - branches = get_layers(stage_idx, block_idx) - return length(branches) == 1 ? only(branches) : - Parallel(connection, branches...) - end - push!(stages, Chain(blocks...)) - end - return Chain(stages...) -end - function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection) + stage_blocks = cnn_stages(get_layers, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -352,7 +264,8 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) -# larger ResNet-like models +# block configurations for larger ResNet-like models that do not use +# depths 18 and 34 const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) From 8092818f0bfff0b8c33f0887c06bc0d75dca034c Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 18 Aug 2022 14:28:40 +0530 Subject: [PATCH 29/34] Huge refactor of MobileNet and EfficientNet families The rise of the builders --- src/convnets/builders/core.jl | 18 +++- src/convnets/builders/mbconv.jl | 108 +++++++++++++++---- src/convnets/builders/resblocks.jl | 2 +- src/convnets/efficientnets/core.jl | 9 +- src/convnets/efficientnets/efficientnet.jl | 29 +++-- src/convnets/efficientnets/efficientnetv2.jl | 75 ++++++------- src/convnets/mobilenets/mobilenetv1.jl | 41 ++++--- src/convnets/mobilenets/mobilenetv2.jl | 23 ++-- src/convnets/mobilenets/mobilenetv3.jl | 91 +++++++--------- src/convnets/resnets/core.jl | 29 ++--- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 48 --------- src/layers/drop.jl | 6 +- src/layers/mbconv.jl | 108 ++++++++++++++++--- src/layers/selayers.jl | 10 +- 15 files changed, 347 insertions(+), 252 deletions(-) diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl index 413c78c27..e02092eca 100644 --- a/src/convnets/builders/core.jl +++ b/src/convnets/builders/core.jl @@ -11,5 +11,21 @@ function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connec end push!(stages, Chain(blocks...)) end - return Chain(stages...) + return stages +end + +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}) + # Construct each stage + stages = [] + for (stage_idx, nblocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + @assert length(branches)==1 "get_layers should return a single branch for each + block if no connection is specified" + return only(branches) + end + push!(stages, Chain(blocks...)) + end + return stages end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index 3312c1cfe..c90b8215c 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -1,44 +1,108 @@ -function mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm, - round_fn = planes -> _round_channels(planes, 8)) +function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + width_mult::Number; norm_layer = BatchNorm, kwargs...) + _, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] + outplanes = floor(Int, outplanes * width_mult) + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + stride = block_idx == 1 ? stride : 1 + block = Chain(dwsep_conv_bn((k, k), inplanes, outplanes, activation; + stride, pad = SamePad(), norm_layer, kwargs...)...) + return (block,) + end + return get_layers, nrepeats +end + +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + scalings::NTuple{2, Real}; norm_layer = BatchNorm, + round_fn = planes -> _round_channels(planes, 8), kwargs...) width_mult, depth_mult = scalings - k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] + block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] inplanes = round_fn(inplanes * width_mult) outplanes = _round_channels(outplanes * width_mult, 8) - function get_layers(block_idx) + function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer, - stride, reduction) + block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction, no_skip = true, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, ceil(Int, nrepeats * depth_mult) end -function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2] - function get_layers(block_idx) +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + width_mult::Real; norm_layer = BatchNorm, kwargs...) + block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + inplanes = _round_channels(inplanes * width_mult, 8) + outplanes = _round_channels(outplanes * width_mult, 8) + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction, no_skip = true, kwargs...) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats +end + +function fused_mbconv_builder(block_configs, inplanes::Integer, + stage_idx::Integer; norm_layer = BatchNorm, kwargs...) + _, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; - norm_layer, stride) + norm_layer, stride, no_skip = true, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, nrepeats end -function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, - residual_fns::AbstractVector; inplanes::Integer, - scalings::NTuple{2, Real} = (1, 1), - norm_layer = BatchNorm) - bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer) - for (stage_idx, residual_fn) in enumerate(residual_fns)] +# TODO - these builders need to be more flexible to potentially specify stuff like +# activation functions and reductions that don't change +function _get_builder(::typeof(dwsep_conv_bn), block_configs, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) + @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument" + return idx -> dwsepconv_builder(block_configs, inplanes, idx, width_mult; norm_layer, + kwargs...) +end + +function _get_builder(::Union{typeof(mbconv), typeof(mbconv_m3)}, block_configs, + inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) + if isnothing(scalings) + return idx -> mbconv_builder(block_configs, inplanes, idx, width_mult; norm_layer, + kwargs...) + elseif isnothing(width_mult) + return idx -> mbconv_builder(block_configs, inplanes, idx, scalings; norm_layer, + kwargs...) + else + throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified")) + end +end + +function _get_builder(::typeof(fused_mbconv), block_configs, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer) + @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument." + @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument" + return idx -> fused_mbconv_builder(block_configs, inplanes, idx; norm_layer) +end + +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, + norm_layer = BatchNorm, kwargs...) + bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes; scalings, + width_mult, norm_layer, kwargs...)(idx) + for idx in eachindex(block_configs)] return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl index 2bf03746f..8343bf811 100644 --- a/src/convnets/builders/resblocks.jl +++ b/src/convnets/builders/resblocks.jl @@ -68,4 +68,4 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; return block, downsample end return get_layers -end \ No newline at end of file +end diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index c37fb8c93..947080481 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,5 +1,4 @@ -function efficientnet(block_configs::AbstractVector{<:Tuple}, - residual_fns::AbstractVector; inplanes::Integer, +function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer, scalings::NTuple{2, Real} = (1, 1), headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, @@ -10,12 +9,12 @@ function efficientnet(block_configs::AbstractVector{<:Tuple}, conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8), swish; norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns; - inplanes, scalings, norm_layer) + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; scalings, + norm_layer) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 495b8658f..aaf958025 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -1,14 +1,22 @@ # block configs for EfficientNet +# data organised as (k, c, e, s, n, r, a) for each stage +# k: kernel size +# c: output channels +# e: expansion ratio +# s: stride +# n: number of repeats +# r: reduction ratio for squeeze-excite layer +# a: activation function const EFFICIENTNET_BLOCK_CONFIGS = [ - # k, c, e, s, n, r, a - (3, 16, 1, 1, 1, 4, swish), - (3, 24, 6, 2, 2, 4, swish), - (5, 40, 6, 2, 2, 4, swish), - (3, 80, 6, 2, 3, 4, swish), - (5, 112, 6, 1, 3, 4, swish), - (5, 192, 6, 2, 4, 4, swish), - (3, 320, 6, 1, 1, 4, swish), + (mbconv, 3, 16, 1, 1, 1, 4, swish), + (mbconv, 3, 24, 6, 2, 2, 4, swish), + (mbconv, 5, 40, 6, 2, 2, 4, swish), + (mbconv, 3, 80, 6, 2, 3, 4, swish), + (mbconv, 5, 112, 6, 1, 3, 4, swish), + (mbconv, 5, 192, 6, 2, 4, 4, swish), + (mbconv, 3, 320, 6, 1, 1, 4, swish), ] + # Data is organised as (r, (w, d)) # r: image resolution # w: width scaling @@ -44,9 +52,8 @@ function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Intege nclasses::Integer = 1000) _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] - layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS, - fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS)); - inplanes = 32, scalings, inchannels, nclasses) + layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings, + inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index 69f67cc81..188875ebf 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,36 +1,39 @@ # block configs for EfficientNetv2 -# data organised as (k, c, e, s, n, r, a) -const EFFNETV2_CONFIGS = Dict(:small => [ - (3, 24, 1, 1, 2, nothing, swish), - (3, 48, 4, 2, 4, nothing, swish), - (3, 64, 4, 2, 4, nothing, swish), - (3, 128, 4, 2, 6, 4, swish), - (3, 160, 6, 1, 9, 4, swish), - (3, 256, 6, 2, 15, 4, swish)], - :medium => [ - (3, 24, 1, 1, 3, nothing, swish), - (3, 48, 4, 2, 5, nothing, swish), - (3, 80, 4, 2, 5, nothing, swish), - (3, 160, 4, 2, 7, 4, swish), - (3, 176, 6, 1, 14, 4, swish), - (3, 304, 6, 2, 18, 4, swish), - (3, 512, 6, 1, 5, 4, swish)], - :large => [ - (3, 32, 1, 1, 4, nothing, swish), - (3, 64, 4, 2, 7, nothing, swish), - (3, 96, 4, 2, 7, nothing, swish), - (3, 192, 4, 2, 10, 4, swish), - (3, 224, 6, 1, 19, 4, swish), - (3, 384, 6, 2, 25, 4, swish), - (3, 640, 6, 1, 7, 4, swish)], - :xlarge => [ - (3, 32, 1, 1, 4, nothing, swish), - (3, 64, 4, 2, 8, nothing, swish), - (3, 96, 4, 2, 8, nothing, swish), - (3, 192, 4, 2, 16, 4, swish), - (3, 384, 6, 1, 24, 4, swish), - (3, 512, 6, 2, 32, 4, swish), - (3, 768, 6, 1, 8, 4, swish)]) +# data organised as (k, c, e, s, n, r, a) for each stage +# k: kernel size +# c: output channels +# e: expansion ratio +# s: stride +# n: number of repeats +# r: reduction ratio for squeeze-excite layer - specified only for `mbconv` +# a: activation function +const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish), + (fused_mbconv, 3, 48, 4, 2, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 4, swish), + (mbconv, 3, 128, 4, 2, 6, 4, swish), + (mbconv, 3, 160, 6, 1, 9, 4, swish), + (mbconv, 3, 256, 6, 2, 15, 4, swish)], + :medium => [(fused_mbconv, 3, 24, 1, 1, 3, swish), + (fused_mbconv, 3, 48, 4, 2, 5, swish), + (fused_mbconv, 3, 80, 4, 2, 5, swish), + (mbconv, 3, 160, 4, 2, 7, 4, swish), + (mbconv, 3, 176, 6, 1, 14, 4, swish), + (mbconv, 3, 304, 6, 2, 18, 4, swish), + (mbconv, 3, 512, 6, 1, 5, 4, swish)], + :large => [(fused_mbconv, 3, 32, 1, 1, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 7, swish), + (fused_mbconv, 3, 96, 4, 2, 7, swish), + (mbconv, 3, 192, 4, 2, 10, 4, swish), + (mbconv, 3, 224, 6, 1, 19, 4, swish), + (mbconv, 3, 384, 6, 2, 25, 4, swish), + (mbconv, 3, 640, 6, 1, 7, 4, swish)], + :xlarge => [(fused_mbconv, 3, 32, 1, 1, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 8, swish), + (fused_mbconv, 3, 96, 4, 2, 8, swish), + (mbconv, 3, 192, 4, 2, 16, 4, swish), + (mbconv, 3, 384, 6, 1, 24, 4, swish), + (mbconv, 3, 512, 6, 2, 32, 4, swish), + (mbconv, 3, 768, 6, 1, 8, 4, swish)]) """ EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, @@ -55,11 +58,9 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - layers = efficientnet(EFFNETV2_CONFIGS[config], - vcat(fill(fused_mbconv_builder, 3), - fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); - inplanes = EFFNETV2_CONFIGS[config][1][2], headplanes = 1280, - inchannels, nclasses) + block_configs = EFFNETV2_CONFIGS[config] + layers = efficientnet(block_configs; inplanes = block_configs[1][3], + headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2-", config)) end diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 542edec81..024f7060b 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -25,33 +25,28 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] - for (dw, outchannels, stride, nrepeats) in config - outchannels = floor(Int, outchannels * width_mult) - for _ in 1:nrepeats - layer = dw ? - dwsep_conv_bn((3, 3), inchannels, outchannels, activation; - stride, pad = 1) : - conv_norm((3, 3), inchannels, outchannels, activation; stride, pad = 1) - append!(layers, layer) - inchannels = outchannels - end - end - return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate)) + # stem of the model + append!(layers, + conv_norm((3, 3), inchannels, config[1][3], activation; stride = 2, pad = 1)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(config, config[1][3]; width_mult) + append!(layers, cnn_stages(get_layers, block_repeats)) + return Chain(Chain(layers...), + create_classifier(config[end][3], nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 const MOBILENETV1_CONFIGS = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), + # k, c, s, r + (dwsep_conv_bn, 3, 64, 1, 1, relu6), + (dwsep_conv_bn, 3, 128, 2, 1, relu6), + (dwsep_conv_bn, 3, 128, 1, 1, relu6), + (dwsep_conv_bn, 3, 256, 2, 1, relu6), + (dwsep_conv_bn, 3, 256, 1, 1, relu6), + (dwsep_conv_bn, 3, 512, 2, 1, relu6), + (dwsep_conv_bn, 3, 512, 1, 5, relu6), + (dwsep_conv_bn, 3, 1024, 2, 1, relu6), + (dwsep_conv_bn, 3, 1024, 1, 1, relu6), ] """ diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index ad41ff967..c233b3d5e 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -33,29 +33,26 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, - fill(mbconv_builder, - length(block_configs)); - inplanes) + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][2], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3], 8), outplanes, relu6)) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 const MOBILENETV2_CONFIGS = [ - # k, c, e, s, n, r, a - (3, 16, 1, 1, 1, nothing, relu6), - (3, 24, 6, 2, 2, nothing, relu6), - (3, 32, 6, 2, 3, nothing, relu6), - (3, 64, 6, 2, 4, nothing, relu6), - (3, 96, 6, 1, 3, nothing, relu6), - (3, 160, 6, 2, 3, nothing, relu6), - (3, 320, 6, 1, 1, nothing, relu6), + # f, k, c, e, s, n r, a + (mbconv, 3, 16, 1, 1, 1, nothing, swish), + (mbconv, 3, 24, 6, 2, 2, nothing, swish), + (mbconv, 3, 32, 6, 2, 3, nothing, swish), + (mbconv, 3, 64, 6, 2, 4, nothing, swish), + (mbconv, 3, 96, 6, 1, 3, nothing, swish), + (mbconv, 3, 160, 6, 2, 3, nothing, swish), + (mbconv, 3, 320, 6, 1, 1, nothing, swish), ] """ diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 82265e125..ed8dda08b 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -25,77 +25,58 @@ Create a MobileNetv3 model. - `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, reduced_tail::Bool = false, - tail_dilated::Bool = false, dropout_rate = 0.2, + max_width::Integer = 1024, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) - explanes = 0 - nstages = length(configs) - reduced_divider = 1 # building inverted residual blocks - for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs) - dilation = 1 - if nstages - i <= 2 - if reduced_tail - reduced_divider = 2 - c /= reduced_divider - end - if tail_dilated - dilation = 2 - end - end - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - mbconv((k, k), inplanes, explanes, outplanes, activation; - stride, reduction, dilation)) - inplanes = outplanes - end + get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult) + append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers - headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8) - append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) + explanes = _round_channels(configs[end][3] * width_mult, 8) + midplanes = _round_channels(explanes * configs[end][4], 8) + headplanes = _round_channels(max_width * width_mult, 8) + append!(layers, conv_norm((1, 1), explanes, midplanes, hardswish)) return Chain(Chain(layers...), - create_classifier(explanes, headplanes, nclasses, + create_classifier(midplanes, headplanes, nclasses, (hardswish, identity); dropout_rate)) end # Layer configurations for small and large models for MobileNetv3 +# Data is organised as (f, k, c, e, s, n, r, a) +# f: mbconv block function - we use `mbconv_m3` for all blocks +# k: kernel size +# c: output channels +# e: expansion factor +# s: stride +# n: number of repeats +# r: squeeze and excite reduction factor +# a: activation function const MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, r, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), + # f, k, c, e, s, n, r, a + (mbconv_m3, 3, 16, 1, 2, 1, 4, relu), + (mbconv_m3, 3, 24, 4.5, 2, 1, nothing, relu), + (mbconv_m3, 3, 24, 3.67, 1, 1, nothing, relu), + (mbconv_m3, 5, 40, 4, 2, 1, 4, hardswish), + (mbconv_m3, 5, 40, 6, 1, 2, 4, hardswish), + (mbconv_m3, 5, 48, 3, 1, 2, 4, hardswish), + (mbconv_m3, 5, 96, 6, 1, 3, 4, hardswish), ], :large => [ - # k, t, c, r, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), + # f, k, c, e, s, n, r, a + (mbconv_m3, 3, 16, 1, 1, 1, nothing, relu), + (mbconv_m3, 3, 24, 4, 2, 1, nothing, relu), + (mbconv_m3, 3, 24, 3, 1, 1, nothing, relu), + (mbconv_m3, 5, 40, 3, 2, 1, 4, relu), + (mbconv_m3, 5, 40, 3, 1, 2, 4, relu), + (mbconv_m3, 3, 80, 6, 2, 1, nothing, hardswish), + (mbconv_m3, 3, 80, 2.5, 1, 1, nothing, hardswish), + (mbconv_m3, 3, 80, 2.3, 1, 2, nothing, hardswish), + (mbconv_m3, 3, 112, 6, 1, 2, 4, hardswish), + (mbconv_m3, 5, 160, 6, 1, 3, 4, hardswish), ]) """ diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 83ab9da04..9e54ec06d 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -71,7 +71,8 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm) + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, + revnorm) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, stride, pad = 1, groups = cardinality) conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) @@ -92,7 +93,8 @@ function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer norm_layer = BatchNorm, revnorm::Bool = false) pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm)...) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, + revnorm)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -161,8 +163,7 @@ on how to use this function. function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, replace_pool::Bool = false, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false) - @assert stem_type in [:default, :deep, :deep_tiered] - "Stem type must be one of [:default, :deep, :deep_tiered]" + _checkconfig(stem_type, [:default, :deep, :deep_tiered]) # Main stem deep_stem = stem_type == :deep || stem_type == :deep_tiered inplanes = deep_stem ? stem_width * 2 : 64 @@ -199,7 +200,7 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte connection, classifier_fn) # Build stages of the ResNet stage_blocks = cnn_stages(get_layers, block_repeats, connection) - backbone = Chain(stem, stage_blocks) + backbone = Chain(stem, stage_blocks...) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] return Chain(backbone, classifier_fn(nfeaturemaps)) @@ -207,12 +208,14 @@ end function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); - cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, + cardinality::Integer = 1, base_width::Integer = 64, + inplanes::Integer = 64, reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, + use_conv::Bool = false, drop_block_rate = nothing, + drop_path_rate = nothing, dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) @@ -234,12 +237,12 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, planes_fn = resnet_planes, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert isnothing(drop_block_rate) - "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" - @assert isnothing(drop_path_rate) - "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" - @assert reduction_factor == 1 - "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" + @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. + Set `drop_block_rate` to nothing." + @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. + Set `drop_path_rate` to nothing." + @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. + Set `reduction_factor` to 1." get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 9bdf1f913..e1b5197f0 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -28,7 +28,7 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mbconv.jl") -export mbconv, fused_mbconv +export mbconv, mbconv_m3, fused_mbconv include("mlp.jl") export mlp_block, gated_mlp_block diff --git a/src/layers/conv.jl b/src/layers/conv.jl index d81d76c9c..e49611280 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -64,7 +64,6 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end - function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...) inplanes, outplanes = ch @@ -78,50 +77,3 @@ function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = r return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, eps = 1.0f-3, kwargs...) end - -""" - dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, - activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, - stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), - pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) - -Create a depthwise separable convolution chain as used in MobileNetv1. -This is sequence of layers: - - - a `kernel_size` depthwise convolution from `inplanes => inplanes` - - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise - `activation` is applied to the convolution output) - - a `kernel_size` convolution from `inplanes => outplanes` - - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise - `activation` is applied to the convolution output) - -See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - -# Arguments - - - `kernel_size`: size of the convolution kernel (tuple) - - `inplanes`: number of input feature maps - - `outplanes`: number of output feature maps - - `activation`: the activation function for the final layer - - `revnorm`: set to `true` to place the batch norm before the convolution - - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and - second convolution - - `bias`: a tuple of two booleans to specify whether to use bias for the first and second - convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and - `use_norm[1] == true`. - - `stride`: stride of the first convolution kernel - - `pad`: padding of the first convolution kernel - - `dilation`: dilation of the first convolution kernel - - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) -""" -function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, - outplanes::Integer, activation = relu; eps::Float32 = 1.0f-5, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), - bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) - return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, - revnorm, use_norm = use_norm[1], stride, bias = bias[1], - groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; eps, - revnorm, use_norm = use_norm[2], bias = bias[2])) -end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 387b562ef..edff55234 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -93,10 +93,8 @@ end trainable(a::DropBlock) = (;) function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_scale) - @assert 0 ≤ drop_block_prob ≤ 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert 0 ≤ gamma_scale ≤ 1 - return "gamma_scale must be between 0 and 1, got $gamma_scale" + @assert 0≤drop_block_prob≤1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0≤gamma_scale≤1 "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index eeac44dbd..74938436d 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -1,3 +1,49 @@ +""" + dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, [bias, weight, init]) + +Create a depthwise separable convolution chain as used in MobileNetv1. +This is sequence of layers: + + - a `kernel_size` depthwise convolution from `inplanes => inplanes` + - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise + `activation` is applied to the convolution output) + - a `kernel_size` convolution from `inplanes => outplanes` + - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise + `activation` is applied to the convolution output) + +See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). + +# Arguments + + - `kernel_size`: size of the convolution kernel (tuple) + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + - `activation`: the activation function for the final layer + - `revnorm`: set to `true` to place the batch norm before the convolution + - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and + second convolution + - `bias`: a tuple of two booleans to specify whether to use bias for the first and second + convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and + `use_norm[1] == true`. + - `stride`: stride of the first convolution kernel + - `pad`: padding of the first convolution kernel + - `dilation`: dilation of the first convolution kernel + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) +""" +function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, + revnorm, use_norm = use_norm[1], stride, bias = bias[1], + groups = inplanes, kwargs...), + conv_norm((1, 1), inplanes, outplanes, activation; eps, + revnorm, use_norm = use_norm[2], bias = bias[2])) +end + # TODO add support for stochastic depth to mbconv and fused_mbconv """ mbconv(kernel_size, inplanes::Integer, explanes::Integer, @@ -21,8 +67,13 @@ Create a basic inverted residual block for MobileNet variants function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, - norm_layer = BatchNorm) + norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, + no_skip::Bool = false) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + if !isnothing(momentum) + @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) + end layers = [] # expand if inplanes != explanes @@ -30,7 +81,6 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end # depthwise - stride = dilation > 1 ? 1 : stride append!(layers, conv_norm(kernel_size, explanes, explanes, activation; norm_layer, stride, dilation, pad = SamePad(), groups = explanes)) @@ -42,25 +92,57 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, end # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) - return Chain(layers...) + use_skip = stride == 1 && inplanes == outplanes && !no_skip + return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) +end + +function mbconv_m3(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, + norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, + no_skip::Bool = false) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + if !isnothing(momentum) + @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) + end + layers = [] + # expand + if inplanes != explanes + append!(layers, + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) + end + # depthwise + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, dilation, pad = SamePad(), groups = explanes)) + # squeeze-excite layer + if !isnothing(reduction) + push!(layers, + squeeze_excite(explanes, _round_channels(explanes ÷ reduction, 8); + activation, + gate_activation = hardσ)) + end + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) + use_skip = stride == 1 && inplanes == outplanes && !no_skip + return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) end function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, norm_layer = BatchNorm) + stride::Integer, norm_layer = BatchNorm, no_skip::Bool = false) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `fused_mbconv`" layers = [] + # fused expand + explanes = explanes == inplanes ? outplanes : explanes + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) if explanes != inplanes - # fused expand - append!(layers, - conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, - pad = SamePad())) # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) - else - append!(layers, - conv_norm(kernel_size, inplanes, outplanes, activation; pad = SamePad(), - norm_layer, stride)) end - return Chain(layers...) + use_skip = stride == 1 && inplanes == outplanes && !no_skip + return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index db034d1af..c9f4873e9 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -22,6 +22,8 @@ Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE- - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer """ +# TODO look into a `get_norm_act` layer that will return a closure over the norm layer +# with the activation function passed in when the norm layer is not `identity` function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; norm_layer = planes -> identity, activation = relu, gate_activation = sigmoid) @@ -34,9 +36,8 @@ function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; gate_activation] return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end - -function squeeze_excite(inplanes::Integer; reduction::Integer = 16, rd_divisor::Integer = 8, - kwargs...) +function squeeze_excite(inplanes::Integer; reduction::Integer = 16, + rd_divisor::Integer = 8, kwargs...) return squeeze_excite(inplanes, _round_channels(inplanes ÷ reduction, rd_divisor, 0); kwargs...) end @@ -54,6 +55,5 @@ Effective squeeze-and-excitation layer. """ function effective_squeeze_excite(inplanes::Integer; gate_activation = sigmoid) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes, inplanes), - gate_activation), .*) + Conv((1, 1), inplanes => inplanes, gate_activation)), .*) end From 510e913feea036ab4c98623b6d1edd2b3491dcda Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 20 Aug 2022 16:18:56 +0530 Subject: [PATCH 30/34] Add MNASNet Also general cleanup --- Project.toml | 2 +- src/Metalhead.jl | 7 +- src/convnets/builders/mbconv.jl | 23 +++-- src/convnets/mobilenets/mnasnet.jl | 121 +++++++++++++++++++++++++ src/convnets/mobilenets/mobilenetv1.jl | 23 +++-- src/convnets/mobilenets/mobilenetv2.jl | 29 ++++-- src/convnets/mobilenets/mobilenetv3.jl | 46 +++++----- src/convnets/resnets/core.jl | 6 +- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 7 +- src/layers/mbconv.jl | 41 +-------- src/utilities.jl | 2 +- 12 files changed, 214 insertions(+), 95 deletions(-) create mode 100644 src/convnets/mobilenets/mnasnet.jl diff --git a/Project.toml b/Project.toml index 691003944..9664dc79d 100644 --- a/Project.toml +++ b/Project.toml @@ -23,10 +23,10 @@ Flux = "0.13" Functors = "0.2, 0.3" CUDA = "3" ChainRulesCore = "1" -PartialFunctions = "1" MLUtils = "0.2.10" NNlib = "0.8" NNlibCUDA = "0.2" +PartialFunctions = "1" julia = "1.6" [publish] diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 8c8800e84..f9be49db1 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -46,6 +46,7 @@ include("convnets/efficientnets/efficientnetv2.jl") include("convnets/mobilenets/mobilenetv1.jl") include("convnets/mobilenets/mobilenetv2.jl") include("convnets/mobilenets/mobilenetv3.jl") +include("convnets/mobilenets/mnasnet.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") @@ -69,14 +70,16 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, - SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, EfficientNetv2, + SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet, + EfficientNet, EfficientNetv2, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, - :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :EfficientNetv2, + :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, + :EfficientNet, :EfficientNetv2, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index c90b8215c..06777826a 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -1,13 +1,13 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, width_mult::Number; norm_layer = BatchNorm, kwargs...) - _, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] + block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] outplanes = floor(Int, outplanes * width_mult) inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes stride = block_idx == 1 ? stride : 1 - block = Chain(dwsep_conv_bn((k, k), inplanes, outplanes, activation; - stride, pad = SamePad(), norm_layer, kwargs...)...) + block = Chain(block_fn((k, k), inplanes, outplanes, activation; + stride, pad = SamePad(), norm_layer, kwargs...)...) return (block,) end return get_layers, nrepeats @@ -15,15 +15,15 @@ end function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real}; norm_layer = BatchNorm, - round_fn = planes -> _round_channels(planes, 8), kwargs...) + divisor::Integer = 8, kwargs...) width_mult, depth_mult = scalings block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] - inplanes = round_fn(inplanes * width_mult) - outplanes = _round_channels(outplanes * width_mult, 8) + inplanes = _round_channels(inplanes * width_mult, divisor) + outplanes = _round_channels(outplanes * width_mult, divisor) function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) + explanes = _round_channels(inplanes * expansion, divisor) stride = block_idx == 1 ? stride : 1 block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, stride, reduction, no_skip = true, kwargs...) @@ -51,14 +51,14 @@ end function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer; norm_layer = BatchNorm, kwargs...) - _, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] + block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 - block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation; - norm_layer, stride, no_skip = true, kwargs...) + block = block_fn((k, k), inplanes, explanes, outplanes, activation; + norm_layer, stride, no_skip = true, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, nrepeats @@ -74,8 +74,7 @@ function _get_builder(::typeof(dwsep_conv_bn), block_configs, inplanes::Integer; kwargs...) end -function _get_builder(::Union{typeof(mbconv), typeof(mbconv_m3)}, block_configs, - inplanes::Integer; +function _get_builder(::typeof(mbconv), block_configs, inplanes::Integer; scalings::Union{Nothing, NTuple{2, Real}} = nothing, width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) if isnothing(scalings) diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl new file mode 100644 index 000000000..e34327105 --- /dev/null +++ b/src/convnets/mobilenets/mnasnet.jl @@ -0,0 +1,121 @@ +# momentum used for BatchNorm as per Tensorflow implementation +const _MNASNET_BN_MOMENTUM = 0.0003f0 + +""" + mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real, + max_width = 1280, dropout_rate = 0.2, inchannels::Integer = 3, + nclasses::Integer = 1000) + +Create an MNASNet model with the specified configuration. +([reference](https://arxiv.org/abs/1807.11626)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper) + - `max_width`: The maximum number of feature maps in any layer of the network + - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes +""" +function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real, + max_width::Integer = 1280, inplanes::Integer = 32, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + # building first layer + inplanes = _round_channels(inplanes * width_mult, 8) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, relu; stride = 2, pad = 1, + momentum = _MNASNET_BN_MOMENTUM)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, + momentum = _MNASNET_BN_MOMENTUM) + append!(layers, cnn_stages(get_layers, block_repeats, +)) + # building last layers + outplanes = _round_channels(block_configs[end][3] * width_mult, 8) + headplanes = _round_channels(max_width * max(1, width_mult), 8) + append!(layers, + conv_norm((1, 1), outplanes, headplanes, relu; momentum = _MNASNET_BN_MOMENTUM)) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) +end + +# Layer configurations for MNasNet +# f: block function - we use `dwsep_conv_bn` for the first block and `mbconv` for the rest +# k: kernel size +# c: output channels +# e: expansion factor - only used for `mbconv` +# s: stride +# n: number of repeats +# r: reduction factor - only used for `mbconv` +# a: activation function +const MNASNET_CONFIGS = Dict(:B1 => (32, + [ + # f, k, c, (e,) s, n, (r,) a + (dwsep_conv_bn, 3, 16, 1, 1, relu), + (mbconv, 3, 24, 3, 2, 3, nothing, relu), + (mbconv, 5, 40, 3, 2, 3, nothing, relu), + (mbconv, 5, 80, 6, 2, 3, nothing, relu), + (mbconv, 3, 96, 6, 1, 2, nothing, relu), + (mbconv, 5, 192, 6, 2, 4, nothing, relu), + (mbconv, 3, 320, 6, 1, 1, nothing, relu), + ]), + :A1 => (32, + [ + (dwsep_conv_bn, 3, 16, 1, 1, relu), + (mbconv, 3, 24, 6, 2, 2, nothing, relu), + (mbconv, 5, 40, 3, 2, 3, 4, relu), + (mbconv, 3, 80, 6, 2, 4, nothing, relu), + (mbconv, 3, 112, 6, 1, 2, 4, relu), + (mbconv, 5, 160, 6, 2, 3, 4, relu), + (mbconv, 3, 320, 6, 1, 1, nothing, relu), + ]) + # TODO small doesn't work yet - need to fix squeeze and excite + # channel calculations somehow + # :small => (8, + # [ + # (dwsep_conv_bn, 3, 8, 1, 1, relu), + # (mbconv, 3, 16, 3, 2, 1, nothing, relu), + # (mbconv, 3, 16, 6, 2, 2, nothing, relu), + # (mbconv, 5, 32, 6, 2, 4, 4, relu), + # (mbconv, 3, 32, 6, 1, 2, 3, relu), + # (mbconv, 5, 88, 6, 2, 3, 3, relu), + # (mbconv, 3, 144, 6, 1, 1, nothing, relu),]), + ) + +""" + MNASNet(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false, + nclasses::Integer = 1000) + +Creates a MNASNet model with the specified configuration. +([reference](https://arxiv.org/abs/1807.11626)) + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes + +!!! warning + + `MNASNet` does not currently support pretrained weights. + +See also [`mnasnet`](#). +""" +struct MNASNet + layers::Any +end +@functor MNASNet + +function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, keys(MNASNET_CONFIGS)) + inplanes, block_configs = MNASNET_CONFIGS[config] + layers = mnasnet(block_configs; width_mult, inplanes, inchannels, nclasses) + if pretrain + load_pretrained!(layers, "mnasnet$(width_mult)") + end + return MNASNet(layers) +end diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 024f7060b..2b4d67110 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -21,7 +21,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). - `inchannels`: The number of input channels. The default value is 3. - `nclasses`: The number of output classes """ -function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; +function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1, activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] @@ -36,8 +36,14 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; end # Layer configurations for MobileNetv1 +# f: block function - we use `dwsep_conv_bn` for all blocks +# k: kernel size +# c: output channels +# s: stride +# n: number of repeats +# a: activation function const MOBILENETV1_CONFIGS = [ - # k, c, s, r + # f, k, c, s, n, a (dwsep_conv_bn, 3, 64, 1, 1, relu6), (dwsep_conv_bn, 3, 128, 2, 1, relu6), (dwsep_conv_bn, 3, 128, 1, 1, relu6), @@ -50,23 +56,26 @@ const MOBILENETV1_CONFIGS = [ ] """ - MobileNetv1(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false, - nclasses::Integer = 1000) + MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model with the baseline configuration ([reference](https://arxiv.org/abs/1704.04861v1)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inchannels`: The number of input channels. - `nclasses`: The number of output classes -See also [`Metalhead.mobilenetv1`](#). +!!! warning + + `MobileNetv1` does not currently support pretrained weights. + +See also [`mobilenetv1`](#). """ struct MobileNetv1 layers::Any diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index c233b3d5e..fdc536ca7 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -33,19 +33,27 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult) + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, + divisor) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers - outplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3], 8), - outplanes, relu6)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) + outplanes = _round_channels(block_configs[end][3], divisor) + headplanes = _round_channels(max_width * max(1, width_mult), divisor) + append!(layers, conv_norm((1, 1), outplanes, headplanes, relu6)) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv2 +# f: block function - we use `dwsep_conv_bn` for the first block and `mbconv` for the rest +# k: kernel size +# c: output channels +# e: expansion factor +# s: stride +# n: number of repeats +# r: reduction factor +# a: activation function const MOBILENETV2_CONFIGS = [ - # f, k, c, e, s, n r, a + # f, k, c, e, s, n, r, a (mbconv, 3, 16, 1, 1, 1, nothing, swish), (mbconv, 3, 24, 6, 2, 2, nothing, swish), (mbconv, 3, 32, 6, 2, 3, nothing, swish), @@ -61,7 +69,6 @@ const MOBILENETV2_CONFIGS = [ Create a MobileNetv2 model with the specified configuration. ([reference](https://arxiv.org/abs/1801.04381)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments @@ -72,7 +79,11 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes -See also [`Metalhead.mobilenetv2`](#). +!!! warning + + `MobileNetv2` does not currently support pretrained weights. + +See also [`mobilenetv2`](#). """ struct MobileNetv2 layers::Any diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index ed8dda08b..bddef7aae 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -33,7 +33,8 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult) + get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult, + se_from_explanes = true) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers explanes = _round_channels(configs[end][3] * width_mult, 8) @@ -46,8 +47,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, end # Layer configurations for small and large models for MobileNetv3 -# Data is organised as (f, k, c, e, s, n, r, a) -# f: mbconv block function - we use `mbconv_m3` for all blocks +# f: mbconv block function - we use `mbconv` for all blocks # k: kernel size # c: output channels # e: expansion factor @@ -57,26 +57,26 @@ end # a: activation function const MOBILENETV3_CONFIGS = Dict(:small => [ # f, k, c, e, s, n, r, a - (mbconv_m3, 3, 16, 1, 2, 1, 4, relu), - (mbconv_m3, 3, 24, 4.5, 2, 1, nothing, relu), - (mbconv_m3, 3, 24, 3.67, 1, 1, nothing, relu), - (mbconv_m3, 5, 40, 4, 2, 1, 4, hardswish), - (mbconv_m3, 5, 40, 6, 1, 2, 4, hardswish), - (mbconv_m3, 5, 48, 3, 1, 2, 4, hardswish), - (mbconv_m3, 5, 96, 6, 1, 3, 4, hardswish), + (mbconv, 3, 16, 1, 2, 1, 4, relu), + (mbconv, 3, 24, 4.5, 2, 1, nothing, relu), + (mbconv, 3, 24, 3.67, 1, 1, nothing, relu), + (mbconv, 5, 40, 4, 2, 1, 4, hardswish), + (mbconv, 5, 40, 6, 1, 2, 4, hardswish), + (mbconv, 5, 48, 3, 1, 2, 4, hardswish), + (mbconv, 5, 96, 6, 1, 3, 4, hardswish), ], :large => [ # f, k, c, e, s, n, r, a - (mbconv_m3, 3, 16, 1, 1, 1, nothing, relu), - (mbconv_m3, 3, 24, 4, 2, 1, nothing, relu), - (mbconv_m3, 3, 24, 3, 1, 1, nothing, relu), - (mbconv_m3, 5, 40, 3, 2, 1, 4, relu), - (mbconv_m3, 5, 40, 3, 1, 2, 4, relu), - (mbconv_m3, 3, 80, 6, 2, 1, nothing, hardswish), - (mbconv_m3, 3, 80, 2.5, 1, 1, nothing, hardswish), - (mbconv_m3, 3, 80, 2.3, 1, 2, nothing, hardswish), - (mbconv_m3, 3, 112, 6, 1, 2, 4, hardswish), - (mbconv_m3, 5, 160, 6, 1, 3, 4, hardswish), + (mbconv, 3, 16, 1, 1, 1, nothing, relu), + (mbconv, 3, 24, 4, 2, 1, nothing, relu), + (mbconv, 3, 24, 3, 1, 1, nothing, relu), + (mbconv, 5, 40, 3, 2, 1, 4, relu), + (mbconv, 5, 40, 3, 1, 2, 4, relu), + (mbconv, 3, 80, 6, 2, 1, nothing, hardswish), + (mbconv, 3, 80, 2.5, 1, 1, nothing, hardswish), + (mbconv, 3, 80, 2.3, 1, 2, nothing, hardswish), + (mbconv, 3, 112, 6, 1, 2, 4, hardswish), + (mbconv, 5, 160, 6, 1, 3, 4, hardswish), ]) """ @@ -97,7 +97,11 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `inchannels`: number of input channels - `nclasses`: the number of output classes -See also [`Metalhead.mobilenetv3`](#). +!!! warning + + `MobileNetv3` does not currently support pretrained weights. + +See also [`mobilenetv3`](#). """ struct MobileNetv3 layers::Any diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 9e54ec06d..8f5846592 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -208,14 +208,12 @@ end function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); - cardinality::Integer = 1, base_width::Integer = 64, - inplanes::Integer = 64, + cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = nothing, - drop_path_rate = nothing, + use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index e1b5197f0..9bdf1f913 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -28,7 +28,7 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mbconv.jl") -export mbconv, mbconv_m3, fused_mbconv +export mbconv, fused_mbconv include("mlp.jl") export mlp_block, gated_mlp_block diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e49611280..f555877b7 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -33,7 +33,7 @@ Create a convolution + batch normalization pair with activation. function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, - bias = !use_norm, kwargs...) + momentum::Union{Nothing, Number} = nothing, bias = !use_norm, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -59,6 +59,11 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, activations = (conv = activation, bn = identity) end end + # handle momentum for BatchNorm + if !isnothing(momentum) + @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) + end # layers layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), norm_layer(normplanes, activations.bn; ϵ = eps)] diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index 74938436d..e740480f8 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -68,8 +68,9 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, - no_skip::Bool = false) + se_from_explanes::Bool = false, divisor::Integer = 8, no_skip::Bool = false) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + # handle momentum for BatchNorm if !isnothing(momentum) @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) @@ -86,42 +87,10 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, stride, dilation, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) + squeeze_planes = _round_channels((se_from_explanes ? explanes : inplanes) ÷ + reduction, divisor) push!(layers, - squeeze_excite(explanes, max(1, inplanes ÷ reduction); activation, - gate_activation = hardσ)) - end - # project - append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) - use_skip = stride == 1 && inplanes == outplanes && !no_skip - return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) -end - -function mbconv_m3(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, - norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, - no_skip::Bool = false) - @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" - if !isnothing(momentum) - @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" - norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) - end - layers = [] - # expand - if inplanes != explanes - append!(layers, - conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) - end - # depthwise - append!(layers, - conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, dilation, pad = SamePad(), groups = explanes)) - # squeeze-excite layer - if !isnothing(reduction) - push!(layers, - squeeze_excite(explanes, _round_channels(explanes ÷ reduction, 8); - activation, - gate_activation = hardσ)) + squeeze_excite(explanes, squeeze_planes; activation, gate_activation = hardσ)) end # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) diff --git a/src/utilities.jl b/src/utilities.jl index 09074b0e8..f208ce360 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -3,7 +3,7 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants -function _round_channels(channels, divisor, min_value = divisor) +function _round_channels(channels::Integer, divisor::Integer, min_value::Integer = divisor) new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) # Make sure that round down does not go down by more than 10% return new_channels < 0.9 * channels ? new_channels + divisor : new_channels From ab2a15eeb56fa263134432fe49eb938726adfcc7 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 20 Aug 2022 18:14:36 +0530 Subject: [PATCH 31/34] Add tests for MNASNet Some fixes --- .github/workflows/CI.yml | 2 +- src/convnets/mobilenets/mobilenetv1.jl | 2 +- src/utilities.jl | 2 +- test/convnets.jl | 71 ++++++++++++++++---------- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5304bc317..f413b4bdc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,7 +27,7 @@ jobs: - x64 suite: - '["AlexNet", "VGG"]' - - '["GoogLeNet", "SqueezeNet", "MobileNet"]' + - '["GoogLeNet", "SqueezeNet", "MobileNets"]' - '"EfficientNet"' - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 2b4d67110..5fafbaf8a 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -84,7 +84,7 @@ end function MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses) + layers = mobilenetv1(MOBILENETV1_CONFIGS; width_mult, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv1")) end diff --git a/src/utilities.jl b/src/utilities.jl index f208ce360..11deb4373 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -3,7 +3,7 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants -function _round_channels(channels::Integer, divisor::Integer, min_value::Integer = divisor) +function _round_channels(channels::Number, divisor::Integer, min_value::Integer = divisor) new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) # Make sure that round down does not go down by more than 10% return new_channels < 0.9 * channels ? new_channels + divisor : new_channels diff --git a/test/convnets.jl b/test/convnets.jl index 34bbb5121..1a2b0562e 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -278,39 +278,54 @@ end end @testset "MobileNet" verbose = true begin - @testset "MobileNetv1" begin - m = MobileNetv1() - @test size(m(x_224)) == (1000, 1) - if MobileNetv1 in PRETRAINED_MODELS - @test acctest(MobileNetv1(pretrain = true)) - else - @test_throws ArgumentError MobileNetv1(pretrain = true) - end - @test gradtest(m, x_224) - end - _gc() - @testset "MobileNetv2" begin - m = MobileNetv2() - @test size(m(x_224)) == (1000, 1) - if MobileNetv2 in PRETRAINED_MODELS - @test acctest(MobileNetv2(pretrain = true)) - else - @test_throws ArgumentError MobileNetv2(pretrain = true) + for width_mult in [0.5, 0.75, 1, 1.3] + @testset "MobileNetv1" begin + m = MobileNetv1(width_mult) + @test size(m(x_224)) == (1000, 1) + if (MobileNetv1, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv1(pretrain = true)) + else + @test_throws ArgumentError MobileNetv1(pretrain = true) + end + @test gradtest(m, x_224) end - @test gradtest(m, x_224) - end - _gc() - @testset "MobileNetv3" verbose = true begin - @testset for config in [:small, :large] - m = MobileNetv3(config) + _gc() + @testset "MobileNetv2" begin + m = MobileNetv2(width_mult) @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) + if (MobileNetv2, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv2(pretrain = true)) else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) + @test_throws ArgumentError MobileNetv2(pretrain = true) end @test gradtest(m, x_224) - _gc() + end + _gc() + @testset "MobileNetv3" verbose = true begin + @testset for config in [:small, :large] + m = MobileNetv3(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) + else + @test_throws ArgumentError MobileNetv3(config; pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + @testset "MNASNet" verbose = true begin + @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) + else + @test_throws ArgumentError MNASNet(config; pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end end end end From 2d310a96e3c0ac07ba64374b7a37576584043c80 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 23 Aug 2022 16:02:40 +0530 Subject: [PATCH 32/34] Final cleanup, hopefully SE layer calculations finally work --- .github/workflows/CI.yml | 2 +- src/convnets/builders/mbconv.jl | 59 +++++++------- src/convnets/efficientnets/efficientnet.jl | 2 +- src/convnets/efficientnets/efficientnetv2.jl | 2 +- src/convnets/mobilenets/mnasnet.jl | 38 ++++----- src/convnets/mobilenets/mobilenetv3.jl | 6 +- src/layers/conv.jl | 7 +- src/layers/mbconv.jl | 83 ++++++++++++++------ src/layers/selayers.jl | 7 +- src/utilities.jl | 2 +- test/convnets.jl | 78 +++++++++--------- 11 files changed, 155 insertions(+), 131 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f413b4bdc..f796a66d7 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,7 +30,7 @@ jobs: - '["GoogLeNet", "SqueezeNet", "MobileNets"]' - '"EfficientNet"' - 'r"/*/ResNet*"' - - '[r"ResNeXt", r"SEResNet"]' + - 'r"/*/SEResNet*"' - '[r"Res2Net", r"Res2NeXt"]' - '"Inception"' - '"DenseNet"' diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index 06777826a..079f05011 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -1,5 +1,5 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, - width_mult::Number; norm_layer = BatchNorm, kwargs...) + width_mult::Real; norm_layer = BatchNorm, kwargs...) block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] outplanes = floor(Int, outplanes * width_mult) inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] @@ -15,9 +15,14 @@ end function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real}; norm_layer = BatchNorm, - divisor::Integer = 8, kwargs...) + divisor::Integer = 8, se_from_explanes::Bool = false, + kwargs...) width_mult, depth_mult = scalings block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes + if !isnothing(reduction) + reduction = !se_from_explanes ? reduction * expansion : reduction + end inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] inplanes = _round_channels(inplanes * width_mult, divisor) outplanes = _round_channels(outplanes * width_mult, divisor) @@ -26,7 +31,7 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, explanes = _round_channels(inplanes * expansion, divisor) stride = block_idx == 1 ? stride : 1 block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, - stride, reduction, no_skip = true, kwargs...) + stride, reduction, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, ceil(Int, nrepeats * depth_mult) @@ -34,23 +39,12 @@ end function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, width_mult::Real; norm_layer = BatchNorm, kwargs...) - block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] - inplanes = _round_channels(inplanes * width_mult, 8) - outplanes = _round_channels(outplanes * width_mult, 8) - function get_layers(block_idx::Integer) - inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) - stride = block_idx == 1 ? stride : 1 - block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, - stride, reduction, no_skip = true, kwargs...) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) - end - return get_layers, nrepeats + return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1); + norm_layer, kwargs...) end -function fused_mbconv_builder(block_configs, inplanes::Integer, - stage_idx::Integer; norm_layer = BatchNorm, kwargs...) +function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer; + norm_layer = BatchNorm, kwargs...) block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] function get_layers(block_idx::Integer) @@ -58,7 +52,7 @@ function fused_mbconv_builder(block_configs, inplanes::Integer, explanes = _round_channels(inplanes * expansion, 8) stride = block_idx == 1 ? stride : 1 block = block_fn((k, k), inplanes, explanes, outplanes, activation; - norm_layer, stride, no_skip = true, kwargs...) + norm_layer, stride, kwargs...) return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) end return get_layers, nrepeats @@ -66,42 +60,45 @@ end # TODO - these builders need to be more flexible to potentially specify stuff like # activation functions and reductions that don't change -function _get_builder(::typeof(dwsep_conv_bn), block_configs, inplanes::Integer; +function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; scalings::Union{Nothing, NTuple{2, Real}} = nothing, width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument" - return idx -> dwsepconv_builder(block_configs, inplanes, idx, width_mult; norm_layer, - kwargs...) + return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, + kwargs...) end -function _get_builder(::typeof(mbconv), block_configs, inplanes::Integer; +function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; scalings::Union{Nothing, NTuple{2, Real}} = nothing, width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) if isnothing(scalings) - return idx -> mbconv_builder(block_configs, inplanes, idx, width_mult; norm_layer, - kwargs...) + return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, + kwargs...) elseif isnothing(width_mult) - return idx -> mbconv_builder(block_configs, inplanes, idx, scalings; norm_layer, - kwargs...) + return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer, + kwargs...) else throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified")) end end -function _get_builder(::typeof(fused_mbconv), block_configs, inplanes::Integer; +function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; scalings::Union{Nothing, NTuple{2, Real}} = nothing, width_mult::Union{Nothing, Number} = nothing, norm_layer) @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument." @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument" - return idx -> fused_mbconv_builder(block_configs, inplanes, idx; norm_layer) + return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer) end function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer; scalings::Union{Nothing, NTuple{2, Real}} = nothing, width_mult::Union{Nothing, Number} = nothing, norm_layer = BatchNorm, kwargs...) - bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes; scalings, - width_mult, norm_layer, kwargs...)(idx) + bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings, + width_mult, norm_layer, kwargs...) for idx in eachindex(block_configs)] return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index aaf958025..5eb81b21d 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -1,5 +1,4 @@ # block configs for EfficientNet -# data organised as (k, c, e, s, n, r, a) for each stage # k: kernel size # c: output channels # e: expansion ratio @@ -7,6 +6,7 @@ # n: number of repeats # r: reduction ratio for squeeze-excite layer # a: activation function +# Data is organised as (k, c, e, s, n, r, a) const EFFICIENTNET_BLOCK_CONFIGS = [ (mbconv, 3, 16, 1, 1, 1, 4, swish), (mbconv, 3, 24, 6, 2, 2, 4, swish), diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index 188875ebf..ff64eea23 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -1,5 +1,4 @@ # block configs for EfficientNetv2 -# data organised as (k, c, e, s, n, r, a) for each stage # k: kernel size # c: output channels # e: expansion ratio @@ -7,6 +6,7 @@ # n: number of repeats # r: reduction ratio for squeeze-excite layer - specified only for `mbconv` # a: activation function +# Data organised as (f, k, c, e, s, n, (r,) a) for each stage const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish), (fused_mbconv, 3, 48, 4, 2, 4, swish), (fused_mbconv, 3, 64, 4, 2, 4, swish), diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index e34327105..1445e1a4b 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -2,7 +2,7 @@ const _MNASNET_BN_MOMENTUM = 0.0003f0 """ - mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real, + mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width = 1280, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -18,24 +18,27 @@ Create an MNASNet model with the specified configuration. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ -function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real, +function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1280, inplanes::Integer = 32, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) + # norm layer for MNASNet is different from other models + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = _MNASNET_BN_MOMENTUM, + kwargs...) # building first layer inplanes = _round_channels(inplanes * width_mult, 8) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes, relu; stride = 2, pad = 1, - momentum = _MNASNET_BN_MOMENTUM)) + norm_layer)) # building inverted residual blocks get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, - momentum = _MNASNET_BN_MOMENTUM) + norm_layer) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(block_configs[end][3] * width_mult, 8) headplanes = _round_channels(max_width * max(1, width_mult), 8) append!(layers, - conv_norm((1, 1), outplanes, headplanes, relu; momentum = _MNASNET_BN_MOMENTUM)) + conv_norm((1, 1), outplanes, headplanes, relu; norm_layer)) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end @@ -48,9 +51,9 @@ end # n: number of repeats # r: reduction factor - only used for `mbconv` # a: activation function +# Data is organised as (f, k, c, (e,) s, n, (r,) a) const MNASNET_CONFIGS = Dict(:B1 => (32, [ - # f, k, c, (e,) s, n, (r,) a (dwsep_conv_bn, 3, 16, 1, 1, relu), (mbconv, 3, 24, 3, 2, 3, nothing, relu), (mbconv, 5, 40, 3, 2, 3, nothing, relu), @@ -68,19 +71,16 @@ const MNASNET_CONFIGS = Dict(:B1 => (32, (mbconv, 3, 112, 6, 1, 2, 4, relu), (mbconv, 5, 160, 6, 2, 3, 4, relu), (mbconv, 3, 320, 6, 1, 1, nothing, relu), - ]) - # TODO small doesn't work yet - need to fix squeeze and excite - # channel calculations somehow - # :small => (8, - # [ - # (dwsep_conv_bn, 3, 8, 1, 1, relu), - # (mbconv, 3, 16, 3, 2, 1, nothing, relu), - # (mbconv, 3, 16, 6, 2, 2, nothing, relu), - # (mbconv, 5, 32, 6, 2, 4, 4, relu), - # (mbconv, 3, 32, 6, 1, 2, 3, relu), - # (mbconv, 5, 88, 6, 2, 3, 3, relu), - # (mbconv, 3, 144, 6, 1, 1, nothing, relu),]), - ) + ]), + :small => (8, + [ + (dwsep_conv_bn, 3, 8, 1, 1, relu), + (mbconv, 3, 16, 3, 2, 1, nothing, relu), + (mbconv, 3, 16, 6, 2, 2, nothing, relu), + (mbconv, 5, 32, 6, 2, 4, 4, relu), + (mbconv, 3, 32, 6, 1, 3, 4, relu), + (mbconv, 5, 88, 6, 2, 3, 4, relu), + (mbconv, 3, 144, 6, 1, 1, nothing, relu)])) """ MNASNet(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false, diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index bddef7aae..220364aaf 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -34,7 +34,8 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) # building inverted residual blocks get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult, - se_from_explanes = true) + se_from_explanes = true, + se_round_fn = _round_channels) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers explanes = _round_channels(configs[end][3] * width_mult, 8) @@ -55,8 +56,8 @@ end # n: number of repeats # r: squeeze and excite reduction factor # a: activation function +# Data is organised as (f, k, c, e, s, n, r, a) const MOBILENETV3_CONFIGS = Dict(:small => [ - # f, k, c, e, s, n, r, a (mbconv, 3, 16, 1, 2, 1, 4, relu), (mbconv, 3, 24, 4.5, 2, 1, nothing, relu), (mbconv, 3, 24, 3.67, 1, 1, nothing, relu), @@ -66,7 +67,6 @@ const MOBILENETV3_CONFIGS = Dict(:small => [ (mbconv, 5, 96, 6, 1, 3, 4, hardswish), ], :large => [ - # f, k, c, e, s, n, r, a (mbconv, 3, 16, 1, 1, 1, nothing, relu), (mbconv, 3, 24, 4, 2, 1, nothing, relu), (mbconv, 3, 24, 3, 1, 1, nothing, relu), diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f555877b7..e49611280 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -33,7 +33,7 @@ Create a convolution + batch normalization pair with activation. function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, - momentum::Union{Nothing, Number} = nothing, bias = !use_norm, kwargs...) + bias = !use_norm, kwargs...) # no normalization layer if !use_norm if preact || revnorm @@ -59,11 +59,6 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, activations = (conv = activation, bn = identity) end end - # handle momentum for BatchNorm - if !isnothing(momentum) - @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" - norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) - end # layers layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), norm_layer(normplanes, activations.bn; ϵ = eps)] diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index e740480f8..e37a98406 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -30,7 +30,6 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). `use_norm[1] == true`. - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - - `dilation`: dilation of the first convolution kernel - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, @@ -46,35 +45,47 @@ end # TODO add support for stochastic depth to mbconv and fused_mbconv """ - mbconv(kernel_size, inplanes::Integer, explanes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing) + mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Real} = nothing, + se_round_fn = x -> round(Int, x), norm_layer = BatchNorm, kwargs...) -Create a basic inverted residual block for MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). +Create a basic inverted residual block for MobileNet and Efficient variants. +This is a sequence of layers: + + - a 1x1 convolution from `inplanes => explanes` followed by a (batch) normalisation layer + + - `activation` if `inplanes != explanes` + - a `kernel_size` depthwise separable convolution from `explanes => explanes` + - a (batch) normalisation layer + - a squeeze-and-excitation block (if `reduction != nothing`) from + `explanes => se_round_fn(explanes / reduction)` and back to `explanes` + - a 1x1 convolution from `explanes => outplanes` + - a (batch) normalisation layer + `activation` + +First introduced in the MobileNetv2 paper. +(See Fig. 3 in [reference](https://arxiv.org/abs/1801.04381v4).) # Arguments - `kernel_size`: kernel size of the convolutional layers - `inplanes`: number of input feature maps - - `explanes`: The number of feature maps in the hidden layer + - `explanes`: The number of expanded feature maps. This is the number of feature maps + after the first 1x1 convolution. - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)) + - `se_round_fn`: The function to round the number of reduced feature maps + in the squeeze and excite layer + - `norm_layer`: The normalization layer to use """ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, - dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, - norm_layer = BatchNorm, momentum::Union{Nothing, Number} = nothing, - se_from_explanes::Bool = false, divisor::Integer = 8, no_skip::Bool = false) + reduction::Union{Nothing, Real} = nothing, + se_round_fn = x -> round(Int, x), norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" - # handle momentum for BatchNorm - if !isnothing(momentum) - @assert norm_layer==BatchNorm "`momentum` is only supported for `BatchNorm`" - norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum, kwargs...) - end layers = [] # expand if inplanes != explanes @@ -84,23 +95,48 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, # depthwise append!(layers, conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, dilation, pad = SamePad(), groups = explanes)) + stride, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) - squeeze_planes = _round_channels((se_from_explanes ? explanes : inplanes) ÷ - reduction, divisor) push!(layers, - squeeze_excite(explanes, squeeze_planes; activation, gate_activation = hardσ)) + squeeze_excite(explanes; round_fn = se_round_fn, reduction, + activation, gate_activation = hardσ)) end # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) - use_skip = stride == 1 && inplanes == outplanes && !no_skip - return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) + return Chain(layers...) end +""" + fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) + +Create a fused inverted residual block. + +This is a sequence of layers: + + - a `kernel_size` depthwise separable convolution from `explanes => explanes` + - a (batch) normalisation layer + - a 1x1 convolution from `explanes => outplanes` followed by a (batch) normalisation + layer + `activation` if `inplanes != explanes` + +Originally introduced by Google in [EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML](https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html). +Later used in the EfficientNetv2 paper. + +# Arguments + + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps + - `explanes`: The number of expanded feature maps + - `outplanes`: The number of output feature maps + - `activation`: The activation function for the first two convolution layer + - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 + - `norm_layer`: The normalization layer to use +""" function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, norm_layer = BatchNorm, no_skip::Bool = false) + stride::Integer, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `fused_mbconv`" layers = [] # fused expand @@ -112,6 +148,5 @@ function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, # project append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) end - use_skip = stride == 1 && inplanes == outplanes && !no_skip - return use_skip ? SkipConnection(Chain(layers...), +) : Chain(layers...) + return Chain(layers...) end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index c9f4873e9..5757c0593 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -36,10 +36,9 @@ function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; gate_activation] return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end -function squeeze_excite(inplanes::Integer; reduction::Integer = 16, - rd_divisor::Integer = 8, kwargs...) - return squeeze_excite(inplanes, _round_channels(inplanes ÷ reduction, rd_divisor, 0); - kwargs...) +function squeeze_excite(inplanes::Integer; reduction::Real = 16, + round_fn = _round_channels, kwargs...) + return squeeze_excite(inplanes, round_fn(inplanes / reduction); kwargs...) end """ diff --git a/src/utilities.jl b/src/utilities.jl index 11deb4373..13b8ec385 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -3,7 +3,7 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants -function _round_channels(channels::Number, divisor::Integer, min_value::Integer = divisor) +function _round_channels(channels::Number, divisor::Integer = 8, min_value::Integer = 0) new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) # Make sure that round down does not go down by more than 10% return new_channels < 0.9 * channels ? new_channels + divisor : new_channels diff --git a/test/convnets.jl b/test/convnets.jl index 1a2b0562e..0c796e24c 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -277,55 +277,53 @@ end end end -@testset "MobileNet" verbose = true begin - for width_mult in [0.5, 0.75, 1, 1.3] - @testset "MobileNetv1" begin - m = MobileNetv1(width_mult) +@testset "MobileNets (width = $width_mult)" for width_mult in [0.5, 0.75, 1, 1.3] + @testset "MobileNetv1" begin + m = MobileNetv1(width_mult) + @test size(m(x_224)) == (1000, 1) + if (MobileNetv1, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv1(pretrain = true)) + else + @test_throws ArgumentError MobileNetv1(pretrain = true) + end + @test gradtest(m, x_224) + end + _gc() + @testset "MobileNetv2" begin + m = MobileNetv2(width_mult) + @test size(m(x_224)) == (1000, 1) + if (MobileNetv2, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv2(pretrain = true)) + else + @test_throws ArgumentError MobileNetv2(pretrain = true) + end + @test gradtest(m, x_224) + end + _gc() + @testset "MobileNetv3" verbose = true begin + @testset for config in [:small, :large] + m = MobileNetv3(config; width_mult) @test size(m(x_224)) == (1000, 1) - if (MobileNetv1, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv1(pretrain = true)) + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) else - @test_throws ArgumentError MobileNetv1(pretrain = true) + @test_throws ArgumentError MobileNetv3(config; pretrain = true) end @test gradtest(m, x_224) + _gc() end - _gc() - @testset "MobileNetv2" begin - m = MobileNetv2(width_mult) + end + @testset "MNASNet" verbose = true begin + @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) @test size(m(x_224)) == (1000, 1) - if (MobileNetv2, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv2(pretrain = true)) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) else - @test_throws ArgumentError MobileNetv2(pretrain = true) + @test_throws ArgumentError MNASNet(config; pretrain = true) end @test gradtest(m, x_224) - end - _gc() - @testset "MobileNetv3" verbose = true begin - @testset for config in [:small, :large] - m = MobileNetv3(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) - else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) - end - @test gradtest(m, x_224) - _gc() - end - end - @testset "MNASNet" verbose = true begin - @testset for config in [:A1, :B1] - m = MNASNet(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MNASNet, config, width_mult) in PRETRAINED_MODELS - @test acctest(MNASNet(config; pretrain = true)) - else - @test_throws ArgumentError MNASNet(config; pretrain = true) - end - @test gradtest(m, x_224) - _gc() - end + _gc() end end end From f76fadbf8719a58d06ff0deac9c68959fb81b252 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 23 Aug 2022 22:47:08 +0530 Subject: [PATCH 33/34] Minor refactor of `cnn_stages` And no, final cleanup was not final. 1. Fix `width_mult` calculations 2. Hopefully all the parameters line up now for all widths of the MobileNets 3. `MNASNet` wasn't `@functor`ed 4. Random docstring link fixes --- src/convnets/builders/core.jl | 24 ++++++------------------ src/convnets/builders/mbconv.jl | 9 ++++++--- src/convnets/efficientnets/core.jl | 7 ++++--- src/convnets/mobilenets/mnasnet.jl | 11 ++++++++--- src/convnets/mobilenets/mobilenetv1.jl | 24 +++++++++++------------- src/convnets/mobilenets/mobilenetv2.jl | 2 +- src/convnets/mobilenets/mobilenetv3.jl | 9 ++++----- src/convnets/resnets/resnet.jl | 4 ++-- src/convnets/resnets/resnext.jl | 2 +- src/convnets/resnets/seresnet.jl | 4 ++-- src/layers/drop.jl | 2 +- src/layers/selayers.jl | 9 ++++----- 12 files changed, 50 insertions(+), 57 deletions(-) diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl index e02092eca..f97f92ff9 100644 --- a/src/convnets/builders/core.jl +++ b/src/convnets/builders/core.jl @@ -1,11 +1,15 @@ -# TODO potentially refactor other CNNs to use this -function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, + connection = nothing) # Construct each stage stages = [] for (stage_idx, nblocks) in enumerate(block_repeats) # Construct the blocks for each stage blocks = map(1:nblocks) do block_idx branches = get_layers(stage_idx, block_idx) + if isnothing(connection) + @assert length(branches)==1 "get_layers should return a single branch for + each block if no connection is specified" + end return length(branches) == 1 ? only(branches) : Parallel(connection, branches...) end @@ -13,19 +17,3 @@ function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connec end return stages end - -function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}) - # Construct each stage - stages = [] - for (stage_idx, nblocks) in enumerate(block_repeats) - # Construct the blocks for each stage - blocks = map(1:nblocks) do block_idx - branches = get_layers(stage_idx, block_idx) - @assert length(branches)==1 "get_layers should return a single branch for each - block if no connection is specified" - return only(branches) - end - push!(stages, Chain(blocks...)) - end - return stages -end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index 079f05011..66235a30f 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -2,7 +2,9 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, width_mult::Real; norm_layer = BatchNorm, kwargs...) block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] outplanes = floor(Int, outplanes * width_mult) - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + if stage_idx != 1 + inplanes = floor(Int, block_configs[stage_idx - 1][3] * width_mult) + end function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes stride = block_idx == 1 ? stride : 1 @@ -23,8 +25,9 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, if !isnothing(reduction) reduction = !se_from_explanes ? reduction * expansion : reduction end - inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] - inplanes = _round_channels(inplanes * width_mult, divisor) + if stage_idx != 1 + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor) + end outplanes = _round_channels(outplanes * width_mult, divisor) function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 947080481..d853e0c1a 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -5,16 +5,17 @@ function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model + inplanes = _round_channels(inplanes * scalings[1]) append!(layers, - conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8), - swish; norm_layer, stride = 2, pad = SamePad())) + conv_norm((3, 3), inchannels, inplanes, swish; norm_layer, stride = 2, + pad = SamePad())) # building inverted residual blocks get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; scalings, norm_layer) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1]), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index 1445e1a4b..9db57e998 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -25,7 +25,7 @@ function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = _MNASNET_BN_MOMENTUM, kwargs...) # building first layer - inplanes = _round_channels(inplanes * width_mult, 8) + inplanes = _round_channels(inplanes * width_mult) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes, relu; stride = 2, pad = 1, @@ -35,8 +35,8 @@ function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, norm_layer) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers - outplanes = _round_channels(block_configs[end][3] * width_mult, 8) - headplanes = _round_channels(max_width * max(1, width_mult), 8) + outplanes = _round_channels(block_configs[end][3] * width_mult) + headplanes = _round_channels(max_width * max(1, width_mult)) append!(layers, conv_norm((1, 1), outplanes, headplanes, relu; norm_layer)) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) @@ -119,3 +119,8 @@ function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, end return MNASNet(layers) end + +(m::MNASNet)(x) = m.layers(x) + +backbone(m::MNASNet) = m.layers[1] +classifier(m::MNASNet) = m.layers[2] diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 5fafbaf8a..36b9ba1bb 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -23,16 +23,18 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). """ function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1, activation = relu, dropout_rate = nothing, - inchannels::Integer = 3, nclasses::Integer = 1000) + inplanes::Integer = 32, inchannels::Integer = 3, + nclasses::Integer = 1000) layers = [] # stem of the model + inplanes = floor(Int, inplanes * width_mult) append!(layers, - conv_norm((3, 3), inchannels, config[1][3], activation; stride = 2, pad = 1)) + conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1)) # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(config, config[1][3]; width_mult) + get_layers, block_repeats = mbconv_stack_builder(config, inplanes; width_mult) append!(layers, cnn_stages(get_layers, block_repeats)) - return Chain(Chain(layers...), - create_classifier(config[end][3], nclasses; dropout_rate)) + outplanes = floor(Int, config[end][3] * width_mult) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 @@ -45,14 +47,10 @@ end const MOBILENETV1_CONFIGS = [ # f, k, c, s, n, a (dwsep_conv_bn, 3, 64, 1, 1, relu6), - (dwsep_conv_bn, 3, 128, 2, 1, relu6), - (dwsep_conv_bn, 3, 128, 1, 1, relu6), - (dwsep_conv_bn, 3, 256, 2, 1, relu6), - (dwsep_conv_bn, 3, 256, 1, 1, relu6), - (dwsep_conv_bn, 3, 512, 2, 1, relu6), - (dwsep_conv_bn, 3, 512, 1, 5, relu6), - (dwsep_conv_bn, 3, 1024, 2, 1, relu6), - (dwsep_conv_bn, 3, 1024, 1, 1, relu6), + (dwsep_conv_bn, 3, 128, 2, 2, relu6), + (dwsep_conv_bn, 3, 256, 2, 2, relu6), + (dwsep_conv_bn, 3, 512, 2, 6, relu6), + (dwsep_conv_bn, 3, 1024, 2, 2, relu6), ] """ diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index fdc536ca7..404650cc5 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -37,7 +37,7 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = divisor) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers - outplanes = _round_channels(block_configs[end][3], divisor) + outplanes = _round_channels(block_configs[end][3] * width_mult, divisor) headplanes = _round_channels(max_width * max(1, width_mult), divisor) append!(layers, conv_norm((1, 1), outplanes, headplanes, relu6)) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 220364aaf..2614c7c2f 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -28,7 +28,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, max_width::Integer = 1024, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer - inplanes = _round_channels(16 * width_mult, 8) + inplanes = _round_channels(16 * width_mult) layers = [] append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) @@ -38,12 +38,11 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, se_round_fn = _round_channels) append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers - explanes = _round_channels(configs[end][3] * width_mult, 8) - midplanes = _round_channels(explanes * configs[end][4], 8) - headplanes = _round_channels(max_width * width_mult, 8) + explanes = _round_channels(configs[end][3] * width_mult) + midplanes = _round_channels(explanes * configs[end][4]) append!(layers, conv_norm((1, 1), explanes, midplanes, hardswish)) return Chain(Chain(layers...), - create_classifier(midplanes, headplanes, nclasses, + create_classifier(midplanes, max_width, nclasses, (hardswish, identity); dropout_rate)) end diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index f935c3b93..b65b71072 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -2,7 +2,7 @@ ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a ResNet model with the specified depth. -((reference)[https://arxiv.org/abs/1512.03385]) +([reference](https://arxiv.org/abs/1512.03385)) # Arguments @@ -39,7 +39,7 @@ classifier(m::ResNet) = m.layers[2] Creates a Wide ResNet model with the specified depth. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same. -((reference)[https://arxiv.org/abs/1605.07146]) +([reference](https://arxiv.org/abs/1605.07146)) # Arguments diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 20fc912a2..2a8fbd561 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -3,7 +3,7 @@ base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a ResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/abs/1611.05431]) +([reference](https://arxiv.org/abs/1611.05431)) # Arguments diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index ff39921b0..487665518 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -2,7 +2,7 @@ SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a SEResNet model with the specified depth. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) +([reference](https://arxiv.org/pdf/1709.01507.pdf)) # Arguments @@ -43,7 +43,7 @@ classifier(m::SEResNet) = m.layers[2] base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a SEResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) +([reference](https://arxiv.org/pdf/1709.01507.pdf)) # Arguments diff --git a/src/layers/drop.jl b/src/layers/drop.jl index edff55234..8a82d5b16 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -70,7 +70,7 @@ At test time, the `DropBlock` layer is equivalent to `identity`. probability across blocks. If this is not desired, then a lower base survival probability is recommended. -((reference)[https://arxiv.org/abs/1810.12890]) +([reference](https://arxiv.org/abs/1810.12890)) # Arguments diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 5757c0593..044d61dbf 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -3,8 +3,9 @@ norm_layer = planes -> identity, activation = relu, gate_activation = sigmoid) - squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, - activation = relu, gate_activation = sigmoid, norm_layer = identity) + squeeze_excite(inplanes::Integer; reduction::Real = 16, + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets. @@ -14,9 +15,7 @@ Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE- - `squeeze_planes`: The number of feature maps in the intermediate layers. Alternatively, specify the keyword arguments `reduction` and `rd_divisior`, which determine the number of feature maps in the intermediate layers from the number of input feature maps as: - `squeeze_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)`. - (See [`_round_channels`](#) for details. The default values are `reduction = 16` and - `rd_divisor = 8`.) + `squeeze_planes = _round_channels(inplanes ÷ reduction)`. (See [`_round_channels`](#) for details.) - `activation`: The activation function for the first convolution layer - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers From 992f6a6d5367f85755c96179e77dfc4c30022e52 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 24 Aug 2022 00:34:43 +0530 Subject: [PATCH 34/34] `_round_channels` all the way And `max_width` doesn't get width scaled --- src/convnets/builders/mbconv.jl | 4 ++-- src/convnets/mobilenets/mnasnet.jl | 7 +++---- src/convnets/mobilenets/mobilenetv1.jl | 4 ++-- src/convnets/mobilenets/mobilenetv2.jl | 2 +- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index 66235a30f..31a936add 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -1,9 +1,9 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, width_mult::Real; norm_layer = BatchNorm, kwargs...) block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] - outplanes = floor(Int, outplanes * width_mult) + outplanes = _round_channels(outplanes * width_mult) if stage_idx != 1 - inplanes = floor(Int, block_configs[stage_idx - 1][3] * width_mult) + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult) end function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index 9db57e998..2f6db2acf 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -36,10 +36,9 @@ function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, cnn_stages(get_layers, block_repeats, +)) # building last layers outplanes = _round_channels(block_configs[end][3] * width_mult) - headplanes = _round_channels(max_width * max(1, width_mult)) append!(layers, - conv_norm((1, 1), outplanes, headplanes, relu; norm_layer)) - return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) + conv_norm((1, 1), outplanes, max_width, relu; norm_layer)) + return Chain(Chain(layers...), create_classifier(max_width, nclasses; dropout_rate)) end # Layer configurations for MNasNet @@ -115,7 +114,7 @@ function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inplanes, block_configs = MNASNET_CONFIGS[config] layers = mnasnet(block_configs; width_mult, inplanes, inchannels, nclasses) if pretrain - load_pretrained!(layers, "mnasnet$(width_mult)") + loadpretrain!(layers, "mnasnet$(width_mult)") end return MNASNet(layers) end diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 36b9ba1bb..24240d0c0 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -27,13 +27,13 @@ function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1, nclasses::Integer = 1000) layers = [] # stem of the model - inplanes = floor(Int, inplanes * width_mult) + inplanes = _round_channels(inplanes * width_mult) append!(layers, conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1)) # building inverted residual blocks get_layers, block_repeats = mbconv_stack_builder(config, inplanes; width_mult) append!(layers, cnn_stages(get_layers, block_repeats)) - outplanes = floor(Int, config[end][3] * width_mult) + outplanes = _round_channels(config[end][3] * width_mult) return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 404650cc5..f3e26862c 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -44,7 +44,7 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = end # Layer configurations for MobileNetv2 -# f: block function - we use `dwsep_conv_bn` for the first block and `mbconv` for the rest +# f: block function - we use `mbconv` for all blocks # k: kernel size # c: output channels # e: expansion factor