
Unify lower level API for EfficientNet and MobileNet model families #200

Merged (8 commits, Sep 4, 2022)
Changes from 2 commits
7 changes: 4 additions & 3 deletions src/Metalhead.jl
@@ -20,26 +20,27 @@ using .Layers

# CNN models
## Builders
include("convnets/builders/core.jl")
include("convnets/builders/irmodel.jl")
include("convnets/builders/mbconv.jl")
include("convnets/builders/resblocks.jl")
include("convnets/builders/resnet.jl")
include("convnets/builders/stages.jl")
## AlexNet and VGG
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
## ResNets
include("convnets/resnets/core.jl")
include("convnets/resnets/res2net.jl")
include("convnets/resnets/resnet.jl")
include("convnets/resnets/resnext.jl")
include("convnets/resnets/seresnet.jl")
include("convnets/resnets/res2net.jl")
## Inceptions
include("convnets/inceptions/googlenet.jl")
include("convnets/inceptions/inceptionv3.jl")
include("convnets/inceptions/inceptionv4.jl")
include("convnets/inceptions/inceptionresnetv2.jl")
include("convnets/inceptions/xception.jl")
## EfficientNets
include("convnets/efficientnets/core.jl")
include("convnets/efficientnets/efficientnet.jl")
include("convnets/efficientnets/efficientnetv2.jl")
## MobileNets
9 changes: 5 additions & 4 deletions src/convnets/alexnet.jl
@@ -1,15 +1,16 @@
"""
alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
alexnet(; dropout_rate = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)

Create an AlexNet model
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

# Arguments

- `dropout_rate`: dropout rate for the classifier
- `inchannels`: The number of input channels.
- `nclasses`: the number of output classes
"""
function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
function alexnet(; dropout_rate = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
MaxPool((3, 3); stride = 2),
Conv((5, 5), 64 => 192, relu; pad = 2),
@@ -19,9 +20,9 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
Conv((3, 3), 256 => 256, relu; pad = 1),
MaxPool((3, 3); stride = 2))
classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten,
Dropout(0.5),
Dropout(dropout_rate),
Dense(256 * 6 * 6, 4096, relu),
Dropout(0.5),
Dropout(dropout_rate),
Dense(4096, 4096, relu),
Dense(4096, nclasses))
return Chain(backbone, classifier)
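The classifier now reads its dropout probability from the new `dropout_rate` keyword instead of the hard-coded `0.5`. A minimal usage sketch (note that the lowercase `alexnet` is the internal builder, not an exported constructor):

using Metalhead

# default matches the previously hard-coded probability (0.5)
model = Metalhead.alexnet()

# hypothetical fine-tuning setup: lighter dropout, 10 output classes
model10 = Metalhead.alexnet(; dropout_rate = 0.25, nclasses = 10)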
41 changes: 41 additions & 0 deletions src/convnets/builders/irmodel.jl
@@ -0,0 +1,41 @@
function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple};
inplanes::Integer = 32, connection = +, activation = relu,
norm_layer = BatchNorm, divisor::Integer = 8,
tail_conv::Bool = true, expanded_classifier::Bool = false,
headplanes::Integer, dropout_rate = nothing,
inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...)
width_mult, _ = scalings
# building first layer
inplanes = _round_channels(inplanes * width_mult, divisor)
layers = []
append!(layers,
conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1,
norm_layer))
# building inverted residual blocks
get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings;
norm_layer, divisor, kwargs...)
append!(layers, cnn_stages(get_layers, block_repeats, connection))
# building last layers
outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
if tail_conv
# special case, supported fully only for MobileNetv3
if expanded_classifier
midplanes = _round_channels(outplanes * block_configs[end][4], divisor)
append!(layers,
conv_norm((1, 1), outplanes, midplanes, activation; norm_layer))
classifier = create_classifier(midplanes, headplanes, nclasses,
(hardswish, identity); dropout_rate)
else
append!(layers,
conv_norm((1, 1), outplanes, headplanes, activation; norm_layer))
classifier = create_classifier(headplanes, nclasses; dropout_rate)
end
else
classifier = create_classifier(outplanes, nclasses; dropout_rate)
end
return Chain(Chain(layers...), classifier)
end

function irmodelbuilder(width_mult::Real, block_configs::AbstractVector{<:Tuple}; kwargs...)
return irmodelbuilder((width_mult, 1), block_configs; kwargs...)
end
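This new `irmodelbuilder` is the shared trunk for the inverted-residual families: a stem convolution, stages built from `block_configs`, and one of three heads selected by `tail_conv` and `expanded_classifier`. A hedged sketch of a direct call; the two-stage config is hypothetical and follows the `(block_fn, kernel, outplanes, expansion, stride, nrepeats, se_reduction, activation)` layout that `mbconv_builder` destructures below:

using Flux: swish
using Metalhead: irmodelbuilder, mbconv

# hypothetical two-stage config, not one shipped by the package
block_configs = [(mbconv, 3, 16, 4, 1, 1, 4, swish),
                 (mbconv, 3, 24, 6, 2, 2, 4, swish)]
model = irmodelbuilder((1.0, 1.0), block_configs; inplanes = 32,
                       headplanes = 1280, activation = swish)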
88 changes: 28 additions & 60 deletions src/convnets/builders/mbconv.jl
@@ -1,9 +1,14 @@
function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
width_mult::Real; norm_layer = BatchNorm, kwargs...)
# TODO - potentially make these builders more flexible to specify stuff like
# activation functions and reductions that don't change over the stages

function dwsepconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
stage_idx::Integer, scalings::NTuple{2, Real};
norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
outplanes = _round_channels(outplanes * width_mult)
outplanes = _round_channels(outplanes * width_mult, divisor)
if stage_idx != 1
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult)
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
end
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
@@ -12,13 +17,14 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
stride, pad = SamePad(), norm_layer, kwargs...)...)
return (block,)
end
return get_layers, nrepeats
return get_layers, ceil(Int, nrepeats * depth_mult)
end
_get_builder(::typeof(dwsep_conv_norm)) = dwsepconv_builder

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
scalings::NTuple{2, Real}; norm_layer = BatchNorm,
divisor::Integer = 8, se_from_explanes::Bool = false,
kwargs...)
function mbconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
stage_idx::Integer, scalings::NTuple{2, Real};
norm_layer = BatchNorm, divisor::Integer = 8,
se_from_explanes::Bool = false, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
# calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
@@ -39,69 +45,31 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
end
return get_layers, ceil(Int, nrepeats * depth_mult)
end
_get_builder(::typeof(mbconv)) = mbconv_builder

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
width_mult::Real; norm_layer = BatchNorm, kwargs...)
return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1);
norm_layer, kwargs...)
end

function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer;
norm_layer = BatchNorm, kwargs...)
function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
stage_idx::Integer, scalings::NTuple{2, Real};
norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
outplanes = _round_channels(outplanes * width_mult, divisor)
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
explanes = _round_channels(inplanes * expansion, 8)
explanes = _round_channels(inplanes * expansion, divisor)
stride = block_idx == 1 ? stride : 1
block = block_fn((k, k), inplanes, explanes, outplanes, activation;
norm_layer, stride, kwargs...)
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
end
return get_layers, nrepeats
end

# TODO - these builders need to be more flexible to potentially specify stuff like
# activation functions and reductions that don't change
function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
@assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument"
return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
kwargs...)
end

function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
if isnothing(scalings)
return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
kwargs...)
elseif isnothing(width_mult)
return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer,
kwargs...)
else
throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified"))
end
end

function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer)
@assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument."
@assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument"
return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer)
return get_layers, ceil(Int, nrepeats * depth_mult)
end
_get_builder(::typeof(fused_mbconv)) = fused_mbconv_builder

function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing,
norm_layer = BatchNorm, kwargs...)
bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings,
width_mult, norm_layer, kwargs...)
function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
scalings::NTuple{2, Real}; kwargs...)
builders = _get_builder.(first.(block_configs))
bxs = [builders[idx](block_configs, inplanes, idx, scalings; kwargs...)
for idx in eachindex(block_configs)]
return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
end
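Together these changes replace the old `scalings`/`width_mult` keyword juggling: every builder now takes a `(width_mult, depth_mult)` tuple, returns a `get_layers` closure plus a depth-scaled repeat count, and `mbconv_stage_builder` selects the right builder by dispatching `_get_builder` on the block function stored first in each config tuple. A hedged sketch of that contract, using a hypothetical single-stage config:

using Flux: relu
using Metalhead: mbconv, mbconv_stage_builder

configs = [(mbconv, 3, 16, 4, 1, 2, 4, relu)]  # hypothetical single stage
get_layers, block_repeats = mbconv_stage_builder(configs, 16, (1.0, 1.5))
@assert block_repeats == [3]  # nrepeats = 2 scaled by depth_mult: ceil(Int, 2 * 1.5)
# get_layers(stage_idx, block_idx) returns a tuple of layers; a leading
# `identity` signals a residual branch for the `connection` given to `cnn_stages`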
9 changes: 9 additions & 0 deletions src/convnets/builders/resnet.jl
@@ -0,0 +1,9 @@
function resnetbuilder(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
connection, classifier_fn)
# Build stages of the ResNet
stage_blocks = cnn_stages(get_layers, block_repeats, connection)
backbone = Chain(stem, stage_blocks...)
# Add classifier to the backbone
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
return Chain(backbone, classifier_fn(nfeaturemaps))
end
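Here `resnetbuilder` infers the classifier's input width via shape propagation instead of a hand-maintained constant. A self-contained sketch of the `Flux.outputsize` idiom it relies on, with a toy backbone and an assumed 32×32 RGB input:

using Flux

backbone = Chain(Conv((3, 3), 3 => 8; pad = 1), MaxPool((2, 2)))
# padbatch = true appends a singleton batch dimension before shape inference
nfeaturemaps = Flux.outputsize(backbone, (32, 32, 3); padbatch = true)[3]  # == 8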
File renamed without changes.
26 changes: 21 additions & 5 deletions src/convnets/convnext.jl
@@ -46,11 +46,11 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
"`planes` should have exactly one value for each block"
downsample_layers = []
push!(downsample_layers,
Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4,
Chain(conv_norm((4, 4), inchannels, planes[1]; stride = 4,
norm_layer = ChannelLayerNorm)...))
for m in 1:(length(depths) - 1)
push!(downsample_layers,
Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2,
Chain(conv_norm((2, 2), planes[m], planes[m + 1]; stride = 2,
norm_layer = ChannelLayerNorm, revnorm = true)...))
end
stages = []
@@ -68,6 +68,12 @@ return Chain(Chain(backbone...), classifier)
return Chain(Chain(backbone...), classifier)
end

function convnext(config::Symbol; drop_path_rate = 0.0, layerscale_init = 1.0f-6,
inchannels::Integer = 3, nclasses::Integer = 1000)
return convnext(CONVNEXT_CONFIGS[config]...; drop_path_rate, layerscale_init,
inchannels, nclasses)
end

# Configurations for ConvNeXt models
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
@@ -76,27 +82,37 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

"""
ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
ConvNeXt(config::Symbol; pretrain::Bool = true, inchannels::Integer = 3,
nclasses::Integer = 1000)

Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments

- `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
- `pretrain`: set to `true` to load pre-trained weights for ImageNet
- `inchannels`: number of input channels
- `nclasses`: number of output classes

!!! warning

`ConvNeXt` does not currently support pretrained weights.

See also [`Metalhead.convnext`](#).
"""
struct ConvNeXt
layers::Any
end
@functor ConvNeXt

function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
function ConvNeXt(config::Symbol; pretrain::Bool = true, inchannels::Integer = 3,
nclasses::Integer = 1000)
_checkconfig(config, keys(CONVNEXT_CONFIGS))
layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses)
layers = convnext(config; inchannels, nclasses)
if pretrain
layers = loadpretrain!(layers, "convnext_$config")
end
return ConvNeXt(layers)
end

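With `convnext(config::Symbol)` in place, the exported constructor becomes a thin wrapper. One caveat visible in this revision: `pretrain` defaults to `true` even though the docstring warns that pretrained weights are unavailable, so passing `pretrain = false` explicitly is the safe call:

using Metalhead

model = ConvNeXt(:tiny; pretrain = false, nclasses = 100)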
21 changes: 0 additions & 21 deletions src/convnets/efficientnets/core.jl

This file was deleted.

16 changes: 12 additions & 4 deletions src/convnets/efficientnets/efficientnet.jl
@@ -31,6 +31,17 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6)))

function efficientnet(config::Symbol; norm_layer = BatchNorm,
dropout_rate = nothing, inchannels::Integer = 3,
nclasses::Integer = 1000)
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
return irmodelbuilder(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32,
norm_layer, activation = swish,
headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4,
dropout_rate, inchannels, nclasses)
end

"""
EfficientNet(config::Symbol; pretrain::Bool = false)

@@ -50,10 +61,7 @@ end

function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
nclasses::Integer = 1000)
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings,
inchannels, nclasses)
layers = efficientnet(config; inchannels, nclasses)
if pretrain
loadpretrain!(layers, string("efficientnet-", config))
end
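The lowercase `efficientnet` now owns config validation and resolves the `(width, depth)` scalings from `EFFICIENTNET_GLOBAL_CONFIGS` before delegating to `irmodelbuilder`, leaving the exported constructor as a pretrain wrapper. Usage is unchanged, e.g.:

using Metalhead

model = EfficientNet(:b0)                  # 224 px input, scalings (1.0, 1.0)
model8 = EfficientNet(:b8; nclasses = 10)  # largest listed config, scalings (2.2, 3.6)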
14 changes: 10 additions & 4 deletions src/convnets/efficientnets/efficientnetv2.jl
@@ -35,6 +35,15 @@ const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish),
(mbconv, 3, 512, 6, 2, 32, 4, swish),
(mbconv, 3, 768, 6, 1, 8, 4, swish)])

function efficientnetv2(config::Symbol; norm_layer = BatchNorm, dropout_rate = nothing,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(config, keys(EFFNETV2_CONFIGS))
block_configs = EFFNETV2_CONFIGS[config]
return irmodelbuilder((1, 1), block_configs; activation = swish, norm_layer,
inplanes = block_configs[1][3], headplanes = 1280,
dropout_rate, inchannels, nclasses)
end

"""
EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1,
inchannels::Integer = 3, nclasses::Integer = 1000)
@@ -57,10 +66,7 @@ end

function EfficientNetv2(config::Symbol; pretrain::Bool = false,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS))))
block_configs = EFFNETV2_CONFIGS[config]
layers = efficientnet(block_configs; inplanes = block_configs[1][3],
headplanes = 1280, inchannels, nclasses)
layers = efficientnetv2(config; inchannels, nclasses)
if pretrain
loadpretrain!(layers, string("efficientnetv2-", config))
end
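Likewise `efficientnetv2` validates the config and calls `irmodelbuilder` with fixed `(1, 1)` scalings, since the v2 variants differ by block configuration rather than width/depth multipliers. Usage sketch:

using Metalhead

model = EfficientNetv2(:small; nclasses = 10)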