diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5e14d6c49..f796a66d7 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,10 +27,10 @@ jobs: - x64 suite: - '["AlexNet", "VGG"]' - - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - - '["EfficientNet"]' + - '["GoogLeNet", "SqueezeNet", "MobileNets"]' + - '"EfficientNet"' - 'r"/*/ResNet*"' - - '[r"ResNeXt", r"SEResNet"]' + - 'r"/*/SEResNet*"' - '[r"Res2Net", r"Res2NeXt"]' - '"Inception"' - '"DenseNet"' diff --git a/Project.toml b/Project.toml index 691003944..9664dc79d 100644 --- a/Project.toml +++ b/Project.toml @@ -23,10 +23,10 @@ Flux = "0.13" Functors = "0.2, 0.3" CUDA = "3" ChainRulesCore = "1" -PartialFunctions = "1" MLUtils = "0.2.10" NNlib = "0.8" NNlibCUDA = "0.2" +PartialFunctions = "1" julia = "1.6" [publish] diff --git a/src/Metalhead.jl b/src/Metalhead.jl index aa236454c..f9be49db1 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -19,6 +19,11 @@ include("layers/Layers.jl") using .Layers # CNN models +## Builders +include("convnets/builders/core.jl") +include("convnets/builders/mbconv.jl") +include("convnets/builders/resblocks.jl") +## AlexNet and VGG include("convnets/alexnet.jl") include("convnets/vgg.jl") ## ResNets @@ -28,19 +33,23 @@ include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") include("convnets/resnets/res2net.jl") ## Inceptions -include("convnets/inception/googlenet.jl") -include("convnets/inception/inceptionv3.jl") -include("convnets/inception/inceptionv4.jl") -include("convnets/inception/inceptionresnetv2.jl") -include("convnets/inception/xception.jl") +include("convnets/inceptions/googlenet.jl") +include("convnets/inceptions/inceptionv3.jl") +include("convnets/inceptions/inceptionv4.jl") +include("convnets/inceptions/inceptionresnetv2.jl") +include("convnets/inceptions/xception.jl") +## EfficientNets +include("convnets/efficientnets/core.jl") +include("convnets/efficientnets/efficientnet.jl") +include("convnets/efficientnets/efficientnetv2.jl") ## MobileNets -include("convnets/mobilenet/mobilenetv1.jl") -include("convnets/mobilenet/mobilenetv2.jl") -include("convnets/mobilenet/mobilenetv3.jl") +include("convnets/mobilenets/mobilenetv1.jl") +include("convnets/mobilenets/mobilenetv2.jl") +include("convnets/mobilenets/mobilenetv3.jl") +include("convnets/mobilenets/mnasnet.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") -include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") @@ -61,13 +70,16 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, - SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, + SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet, + EfficientNet, EfficientNetv2, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, - :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, - :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, +for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, + :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, + :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, + :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, + :EfficientNet, :EfficientNetv2, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl new file mode 100644 index 000000000..f97f92ff9 --- /dev/null +++ b/src/convnets/builders/core.jl @@ -0,0 +1,19 @@ +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, + connection = nothing) + # Construct each stage + stages = [] + for (stage_idx, nblocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + if isnothing(connection) + @assert length(branches)==1 "get_layers should return a single branch for + each block if no connection is specified" + end + return length(branches) == 1 ? only(branches) : + Parallel(connection, branches...) + end + push!(stages, Chain(blocks...)) + end + return stages +end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl new file mode 100644 index 000000000..31a936add --- /dev/null +++ b/src/convnets/builders/mbconv.jl @@ -0,0 +1,107 @@ +function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + width_mult::Real; norm_layer = BatchNorm, kwargs...) + block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] + outplanes = _round_channels(outplanes * width_mult) + if stage_idx != 1 + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult) + end + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + stride = block_idx == 1 ? stride : 1 + block = Chain(block_fn((k, k), inplanes, outplanes, activation; + stride, pad = SamePad(), norm_layer, kwargs...)...) + return (block,) + end + return get_layers, nrepeats +end + +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + scalings::NTuple{2, Real}; norm_layer = BatchNorm, + divisor::Integer = 8, se_from_explanes::Bool = false, + kwargs...) + width_mult, depth_mult = scalings + block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes + if !isnothing(reduction) + reduction = !se_from_explanes ? reduction * expansion : reduction + end + if stage_idx != 1 + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor) + end + outplanes = _round_channels(outplanes * width_mult, divisor) + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, divisor) + stride = block_idx == 1 ? stride : 1 + block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, + stride, reduction, kwargs...) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, ceil(Int, nrepeats * depth_mult) +end + +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, + width_mult::Real; norm_layer = BatchNorm, kwargs...) + return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1); + norm_layer, kwargs...) +end + +function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer; + norm_layer = BatchNorm, kwargs...) + block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + function get_layers(block_idx::Integer) + inplanes = block_idx == 1 ? inplanes : outplanes + explanes = _round_channels(inplanes * expansion, 8) + stride = block_idx == 1 ? stride : 1 + block = block_fn((k, k), inplanes, explanes, outplanes, activation; + norm_layer, stride, kwargs...) + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + end + return get_layers, nrepeats +end + +# TODO - these builders need to be more flexible to potentially specify stuff like +# activation functions and reductions that don't change +function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) + @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument" + return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, + kwargs...) +end + +function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) + if isnothing(scalings) + return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, + kwargs...) + elseif isnothing(width_mult) + return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer, + kwargs...) + else + throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified")) + end +end + +function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, norm_layer) + @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument." + @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument" + return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer) +end + +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer; + scalings::Union{Nothing, NTuple{2, Real}} = nothing, + width_mult::Union{Nothing, Number} = nothing, + norm_layer = BatchNorm, kwargs...) + bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings, + width_mult, norm_layer, kwargs...) + for idx in eachindex(block_configs)] + return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) +end diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl new file mode 100644 index 000000000..8343bf811 --- /dev/null +++ b/src/convnets/builders/resblocks.jl @@ -0,0 +1,71 @@ +function basicblock_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 1, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + drop_block_rate = nothing, drop_path_rate = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # Also get `planes_vec` needed for block `inplanes` and `planes` calculations + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = stride != 1 || inplanes != planes * expansion ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = basicblock(inplanes, planes; stride, reduction_factor, activation, + norm_layer, revnorm, attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end + +function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + drop_block_rate = nothing, drop_path_rate = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = stride != 1 || inplanes != planes * expansion ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = bottleneck(inplanes, planes; stride, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 309989d2d..bc1a71a5f 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -1,5 +1,5 @@ """ - convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), + convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -13,20 +13,25 @@ Creates a ConvMixer model. - `kernel_size`: kernel size of the convolutional layers - `patch_size`: size of the patches - `activation`: activation function used after the convolutional layers - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ -function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), - patch_size::Dims{2} = (7, 7), activation = gelu, +function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), + patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) - stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, - stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; + layers = [] + # stem of the model + append!(layers, + conv_norm(patch_size, inchannels, planes, activation; preact = true, + stride = patch_size[1])) + # stages of the model + stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; preact = true, groups = planes, pad = SamePad())), +), conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] - return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses)) + append!(layers, stages) + return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate)) end const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), @@ -48,7 +53,7 @@ Creates a ConvMixer model. # Arguments - `config`: the size of the model, either `:base`, `:small` or `:large` - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of classes in the output """ struct ConvMixer diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 040a409ab..15271cfed 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -84,7 +84,7 @@ Creates a ConvNeXt model. # Arguments - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`. - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `nclasses`: number of output classes See also [`Metalhead.convnext`](#). diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index eb29c4966..a7c367c1c 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -10,12 +10,12 @@ Create a Densenet bottleneck layer - `outplanes`: number of output feature maps on bottleneck branch (and scaling factor for inner feature maps; see ref) """ -function dense_bottleneck(inplanes::Integer, outplanes::Integer) - inner_channels = 4 * outplanes - return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, +function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) + inner_channels = expansion * outplanes + return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, - bias = false, revnorm = true)...), + revnorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes::Integer, outplanes::Integer) - return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; revnorm = true)..., MeanPool((2, 2))) end @@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates) end """ - densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000) + densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a DenseNet model ([reference](https://arxiv.org/abs/1608.06993)). @@ -66,13 +67,14 @@ Create a DenseNet model - `growth_rates`: the growth rates of output feature maps within each [`dense_block`](#) (a vector of vectors) - `reduction`: the factor by which the number of feature maps is scaled across each transition + - `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes """ -function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3, - nclasses::Integer = 1000) +function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] append!(layers, - conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false)) + conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) @@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: inplanes = floor(Int, outplanes * reduction) end push!(layers, BatchNorm(outplanes, relu)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end """ @@ -100,9 +102,10 @@ Create a DenseNet model - `nclasses`: the number of output classes """ function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, - reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) + reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; - reduction, inchannels, nclasses) + reduction, dropout_rate, inchannels, nclasses) end const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], @@ -132,7 +135,8 @@ end function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32, reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(DENSENET_CONFIGS)) - layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses) + layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, + nclasses) if pretrain loadpretrain!(layers, string("densenet", config)) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl deleted file mode 100644 index 91986fb92..000000000 --- a/src/convnets/efficientnet.jl +++ /dev/null @@ -1,116 +0,0 @@ -""" - efficientnet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - -# Arguments - - - `scalings`: global width and depth scaling (given as a tuple) - - - `block_configs`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `n`: number of block repetitions (will be scaled by global depth scaling) - + `k`: kernel size - + `s`: kernel stride - + `e`: expansion ratio - + `i`: block input channels (will be scaled by global width scaling) - + `o`: block output channels (will be scaled by global width scaling) - - `inchannels`: number of input channels - - `nclasses`: number of output classes - - `max_width`: maximum number of output channels before the fully connected - classification blocks -""" -function efficientnet(scalings::NTuple{2, Real}, - block_configs::AbstractVector{NTuple{6, Int}}; - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) - wscale, dscale = scalings - scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) - scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - out_channels = _round_channels(scalew(32), 8) - stem = conv_norm((3, 3), inchannels, out_channels, swish; bias = false, stride = 2, - pad = SamePad()) - blocks = [] - for (n, k, s, e, i, o) in block_configs - in_channels = _round_channels(scalew(i), 8) - out_channels = _round_channels(scalew(o), 8) - repeats = scaled(n) - push!(blocks, - invertedresidual((k, k), in_channels, out_channels, swish; expansion = e, - stride = s, reduction = 4)) - for _ in 1:(repeats - 1) - push!(blocks, - invertedresidual((k, k), out_channels, out_channels, swish; expansion = e, - stride = 1, reduction = 4)) - end - end - head_out_channels = _round_channels(max_width, 8) - append!(blocks, - conv_norm((1, 1), out_channels, head_out_channels, swish; - bias = false, pad = SamePad())) - return Chain(Chain(stem..., blocks...), create_classifier(head_out_channels, nclasses)) -end - -# n: # of block repetitions -# k: kernel size k x k -# s: stride -# e: expantion ratio -# i: block input channels -# o: block output channels -const EFFICIENTNET_BLOCK_CONFIGS = [ - # (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), - (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320), -] - -# w: width scaling -# d: depth scaling -# r: image resolution -# Data is organised as (r, (w, d)) -const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6))) - -""" - EfficientNet(config::Symbol; pretrain::Bool = false) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). -See also [`efficientnet`](#). - -# Arguments - - - `config`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet -""" -struct EfficientNet - layers::Any -end -@functor EfficientNet - -function EfficientNet(config::Symbol; pretrain::Bool = false) - _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS) - if pretrain - loadpretrain!(model, string("efficientnet-", config)) - end - return model -end - -(m::EfficientNet)(x) = m.layers(x) - -backbone(m::EfficientNet) = m.layers[1] -classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl new file mode 100644 index 000000000..d853e0c1a --- /dev/null +++ b/src/convnets/efficientnets/core.jl @@ -0,0 +1,21 @@ +function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer, + scalings::NTuple{2, Real} = (1, 1), + headplanes::Integer = block_configs[end][3] * 4, + norm_layer = BatchNorm, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) + layers = [] + # stem of the model + inplanes = _round_channels(inplanes * scalings[1]) + append!(layers, + conv_norm((3, 3), inchannels, inplanes, swish; norm_layer, stride = 2, + pad = SamePad())) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; scalings, + norm_layer) + append!(layers, cnn_stages(get_layers, block_repeats, +)) + # building last layers + append!(layers, + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1]), + headplanes, swish; pad = SamePad())) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) +end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl new file mode 100644 index 000000000..5eb81b21d --- /dev/null +++ b/src/convnets/efficientnets/efficientnet.jl @@ -0,0 +1,66 @@ +# block configs for EfficientNet +# k: kernel size +# c: output channels +# e: expansion ratio +# s: stride +# n: number of repeats +# r: reduction ratio for squeeze-excite layer +# a: activation function +# Data is organised as (k, c, e, s, n, r, a) +const EFFICIENTNET_BLOCK_CONFIGS = [ + (mbconv, 3, 16, 1, 1, 1, 4, swish), + (mbconv, 3, 24, 6, 2, 2, 4, swish), + (mbconv, 5, 40, 6, 2, 2, 4, swish), + (mbconv, 3, 80, 6, 2, 3, 4, swish), + (mbconv, 5, 112, 6, 1, 3, 4, swish), + (mbconv, 5, 192, 6, 2, 4, 4, swish), + (mbconv, 3, 320, 6, 1, 1, 4, swish), +] + +# Data is organised as (r, (w, d)) +# r: image resolution +# w: width scaling +# d: depth scaling +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) + +""" + EfficientNet(config::Symbol; pretrain::Bool = false) + +Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). +See also [`efficientnet`](#). + +# Arguments + + - `config`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet +""" +struct EfficientNet + layers::Any +end +@functor EfficientNet + +function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) + _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] + layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings, + inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnet-", config)) + end + return EfficientNet(layers) +end + +(m::EfficientNet)(x) = m.layers(x) + +backbone(m::EfficientNet) = m.layers[1] +classifier(m::EfficientNet) = m.layers[2] diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl new file mode 100644 index 000000000..ff64eea23 --- /dev/null +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -0,0 +1,73 @@ +# block configs for EfficientNetv2 +# k: kernel size +# c: output channels +# e: expansion ratio +# s: stride +# n: number of repeats +# r: reduction ratio for squeeze-excite layer - specified only for `mbconv` +# a: activation function +# Data organised as (f, k, c, e, s, n, (r,) a) for each stage +const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish), + (fused_mbconv, 3, 48, 4, 2, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 4, swish), + (mbconv, 3, 128, 4, 2, 6, 4, swish), + (mbconv, 3, 160, 6, 1, 9, 4, swish), + (mbconv, 3, 256, 6, 2, 15, 4, swish)], + :medium => [(fused_mbconv, 3, 24, 1, 1, 3, swish), + (fused_mbconv, 3, 48, 4, 2, 5, swish), + (fused_mbconv, 3, 80, 4, 2, 5, swish), + (mbconv, 3, 160, 4, 2, 7, 4, swish), + (mbconv, 3, 176, 6, 1, 14, 4, swish), + (mbconv, 3, 304, 6, 2, 18, 4, swish), + (mbconv, 3, 512, 6, 1, 5, 4, swish)], + :large => [(fused_mbconv, 3, 32, 1, 1, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 7, swish), + (fused_mbconv, 3, 96, 4, 2, 7, swish), + (mbconv, 3, 192, 4, 2, 10, 4, swish), + (mbconv, 3, 224, 6, 1, 19, 4, swish), + (mbconv, 3, 384, 6, 2, 25, 4, swish), + (mbconv, 3, 640, 6, 1, 7, 4, swish)], + :xlarge => [(fused_mbconv, 3, 32, 1, 1, 4, swish), + (fused_mbconv, 3, 64, 4, 2, 8, swish), + (fused_mbconv, 3, 96, 4, 2, 8, swish), + (mbconv, 3, 192, 4, 2, 16, 4, swish), + (mbconv, 3, 384, 6, 1, 24, 4, swish), + (mbconv, 3, 512, 6, 2, 32, 4, swish), + (mbconv, 3, 768, 6, 1, 8, 4, swish)]) + +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `width_mult`: Controls the number of output feature maps in each block (with 1 + being the default in the paper) + - `inchannels`: number of input channels + - `nclasses`: number of output classes +""" +struct EfficientNetv2 + layers::Any +end +@functor EfficientNetv2 + +function EfficientNetv2(config::Symbol; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) + block_configs = EFFNETV2_CONFIGS[config] + layers = efficientnet(block_configs; inplanes = block_configs[1][3], + headplanes = 1280, inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("efficientnetv2-", config)) + end + return EfficientNetv2(layers) +end + +(m::EfficientNetv2)(x) = m.layers(x) + +backbone(m::EfficientNetv2) = m.layers[1] +classifier(m::EfficientNetv2) = m.layers[2] diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl deleted file mode 100644 index cd4971742..000000000 --- a/src/convnets/inception/inceptionv4.jl +++ /dev/null @@ -1,156 +0,0 @@ -function mixed_3a() - return Parallel(cat_channels, - MaxPool((3, 3); stride = 2), - Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) -end - -function mixed_4a() - return Parallel(cat_channels, - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((3, 3), 64, 96)...), - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((1, 7), 64, 64; pad = (0, 3))..., - conv_norm((7, 1), 64, 64; pad = (3, 0))..., - conv_norm((3, 3), 64, 96)...)) -end - -function mixed_5a() - return Parallel(cat_channels, - Chain(conv_norm((3, 3), 192, 192; stride = 2)...), - MaxPool((3, 3); stride = 2)) -end - -function inceptionv4_a() - branch1 = Chain(conv_norm((1, 1), 384, 96)...) - branch2 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_a() - branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 384, 192)..., - conv_norm((3, 3), 192, 224; pad = 1)..., - conv_norm((3, 3), 224, 256; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_b() - branch1 = Chain(conv_norm((1, 1), 1024, 384)...) - branch2 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((1, 7), 192, 224; pad = (0, 3))..., - conv_norm((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((7, 1), 192, 192; pad = (0, 3))..., - conv_norm((1, 7), 192, 224; pad = (3, 0))..., - conv_norm((7, 1), 224, 224; pad = (0, 3))..., - conv_norm((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_b() - branch1 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1024, 256)..., - conv_norm((1, 7), 256, 256; pad = (0, 3))..., - conv_norm((7, 1), 256, 320; pad = (3, 0))..., - conv_norm((3, 3), 320, 320; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_c() - branch1 = Chain(conv_norm((1, 1), 1536, 256)...) - branch2 = Chain(conv_norm((1, 1), 1536, 384)..., - Parallel(cat_channels, - Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_norm((1, 1), 1536, 384)..., - conv_norm((3, 1), 384, 448; pad = (1, 0))..., - conv_norm((1, 3), 448, 512; pad = (0, 1))..., - Parallel(cat_channels, - Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -""" - inceptionv4(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) - -Create an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, - nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - mixed_3a(), - mixed_4a(), - mixed_5a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - reduction_a(), # mixed_6a - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - reduction_b(), # mixed_7a - inceptionv4_c(), - inceptionv4_c(), - inceptionv4_c()) - return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) -end - -""" - Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - -Creates an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: number of input channels. - - `nclasses`: the number of output classes. - -!!! warning - - `Inceptionv4` does not currently support pretrained weights. -""" -struct Inceptionv4 - layers::Any -end -@functor Inceptionv4 - -function Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) - layers = inceptionv4(; inchannels, nclasses) - if pretrain - loadpretrain!(layers, "Inceptionv4") - end - return Inceptionv4(layers) -end - -(m::Inceptionv4)(x) = m.layers(x) - -backbone(m::Inceptionv4) = m.layers[1] -classifier(m::Inceptionv4) = m.layers[2] diff --git a/src/convnets/inception/googlenet.jl b/src/convnets/inceptions/googlenet.jl similarity index 100% rename from src/convnets/inception/googlenet.jl rename to src/convnets/inceptions/googlenet.jl diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl similarity index 50% rename from src/convnets/inception/inceptionresnetv2.jl rename to src/convnets/inceptions/inceptionresnetv2.jl index a8bfbaefa..bd88648e9 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -1,70 +1,70 @@ function mixed_5b() - branch1 = Chain(conv_norm((1, 1), 192, 96)...) - branch2 = Chain(conv_norm((1, 1), 192, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_norm((1, 1), 192, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) + branch1 = Chain(basic_conv_bn((1, 1), 192, 96)...) + branch2 = Chain(basic_conv_bn((1, 1), 192, 48)..., + basic_conv_bn((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(basic_conv_bn((1, 1), 192, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), 192, 64)...) + basic_conv_bn((1, 1), 192, 64)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block35(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 320, 32)...) - branch2 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 48; pad = 1)..., - conv_norm((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_norm((1, 1), 128, 320)...) + branch1 = Chain(basic_conv_bn((1, 1), 320, 32)...) + branch2 = Chain(basic_conv_bn((1, 1), 320, 32)..., + basic_conv_bn((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(basic_conv_bn((1, 1), 320, 32)..., + basic_conv_bn((3, 3), 32, 48; pad = 1)..., + basic_conv_bn((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(basic_conv_bn((1, 1), 128, 320)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), branch4, inputscale(scale; activation = relu)), +) end function mixed_6a() - branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 320, 256)..., - conv_norm((3, 3), 256, 256; pad = 1)..., - conv_norm((3, 3), 256, 384; stride = 2)...) + branch1 = Chain(basic_conv_bn((3, 3), 320, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 320, 256)..., + basic_conv_bn((3, 3), 256, 256; pad = 1)..., + basic_conv_bn((3, 3), 256, 384; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function block17(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 1088, 192)...) - branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((1, 7), 128, 160; pad = (0, 3))..., - conv_norm((7, 1), 160, 192; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 384, 1088)...) + branch1 = Chain(basic_conv_bn((1, 1), 1088, 192)...) + branch2 = Chain(basic_conv_bn((1, 1), 1088, 128)..., + basic_conv_bn((7, 1), 128, 160; pad = (3, 0))..., + basic_conv_bn((1, 7), 160, 192; pad = (0, 3))...) + branch3 = Chain(basic_conv_bn((1, 1), 384, 1088)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation = relu)), +) end function mixed_7a() - branch1 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; pad = 1)..., - conv_norm((3, 3), 288, 320; stride = 2)...) + branch1 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(basic_conv_bn((1, 1), 1088, 256)..., + basic_conv_bn((3, 3), 256, 288; pad = 1)..., + basic_conv_bn((3, 3), 288, 320; stride = 2)...) branch4 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_norm((1, 1), 2080, 192)...) - branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((1, 3), 192, 224; pad = (0, 1))..., - conv_norm((3, 1), 224, 256; pad = (1, 0))...) - branch3 = Chain(conv_norm((1, 1), 448, 2080)...) + branch1 = Chain(basic_conv_bn((1, 1), 2080, 192)...) + branch2 = Chain(basic_conv_bn((1, 1), 2080, 192)..., + basic_conv_bn((3, 1), 192, 224; pad = (1, 0))..., + basic_conv_bn((1, 3), 224, 256; pad = (0, 1))...) + branch3 = Chain(basic_conv_bn((1, 1), 448, 2080)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation)), +) end """ - inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -72,17 +72,17 @@ Creates an InceptionResNetv2 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, +function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_norm((3, 3), 64, 80)..., - conv_norm((3, 3), 80, 192)..., + basic_conv_bn((3, 3), 64, 80)..., + basic_conv_bn((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), mixed_5b(), [block35(0.17f0) for _ in 1:10]..., @@ -91,13 +91,13 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, mixed_7a(), [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), - conv_norm((1, 1), 2080, 1536)...) + basic_conv_bn((1, 1), 2080, 1536)...) return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end """ - InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) + InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -106,7 +106,6 @@ Creates an InceptionResNetv2 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl similarity index 60% rename from src/convnets/inception/inceptionv3.jl rename to src/convnets/inceptions/inceptionv3.jl index bc5ec3a2b..32fbbede5 100644 --- a/src/convnets/inception/inceptionv3.jl +++ b/src/convnets/inceptions/inceptionv3.jl @@ -10,14 +10,14 @@ Create an Inception-v3 style-A module - `pool_proj`: the number of output feature maps for the pooling projection """ function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 64)...) + branch5x5 = Chain(basic_conv_bn((1, 1), inplanes, 48)..., + basic_conv_bn((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(basic_conv_bn((1, 1), inplanes, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, pool_proj)...) + basic_conv_bn((1, 1), inplanes, pool_proj)...) return Parallel(cat_channels, branch1x1, branch5x5, branch3x3, branch_pool) end @@ -33,10 +33,10 @@ Create an Inception-v3 style-B module - `inplanes`: number of input feature maps """ function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; stride = 2)...) + branch3x3_1 = Chain(basic_conv_bn((3, 3), inplanes, 384; stride = 2)...) + branch3x3_2 = Chain(basic_conv_bn((1, 1), inplanes, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3_1, branch3x3_2, branch_pool) @@ -55,17 +55,17 @@ Create an Inception-v3 style-C module - `n`: the "grid size" (kernel size) for the convolution layers """ function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) - branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 192)...) + branch7x7_1 = Chain(basic_conv_bn((1, 1), inplanes, inner_planes)..., + basic_conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + basic_conv_bn((1, n), inner_planes, 192; pad = (0, 3))...) + branch7x7_2 = Chain(basic_conv_bn((1, 1), inplanes, inner_planes)..., + basic_conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., + basic_conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + basic_conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., + basic_conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) + basic_conv_bn((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, branch7x7_1, branch7x7_2, branch_pool) end @@ -81,12 +81,12 @@ Create an Inception-v3 style-D module - `inplanes`: number of input feature maps """ function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((1, 7), 192, 192; pad = (0, 3))..., - conv_norm((7, 1), 192, 192; pad = (3, 0))..., - conv_norm((3, 3), 192, 192; stride = 2)...) + branch3x3 = Chain(basic_conv_bn((1, 1), inplanes, 192)..., + basic_conv_bn((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(basic_conv_bn((1, 1), inplanes, 192)..., + basic_conv_bn((7, 1), 192, 192; pad = (3, 0))..., + basic_conv_bn((1, 7), 192, 192; pad = (0, 3))..., + basic_conv_bn((3, 3), 192, 192; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3, branch7x7x3, branch_pool) @@ -103,16 +103,16 @@ Create an Inception-v3 style-E module - `inplanes`: number of input feature maps """ function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., - conv_norm((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch1x1 = Chain(basic_conv_bn((1, 1), inplanes, 320)...) + branch3x3_1 = Chain(basic_conv_bn((1, 1), inplanes, 384)...) + branch3x3_1a = Chain(basic_conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_1b = Chain(basic_conv_bn((1, 3), 384, 384; pad = (0, 1))...) + branch3x3_2 = Chain(basic_conv_bn((1, 1), inplanes, 448)..., + basic_conv_bn((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(basic_conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_2b = Chain(basic_conv_bn((1, 3), 384, 384; pad = (0, 1))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) + basic_conv_bn((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, Chain(branch3x3_1, @@ -133,13 +133,14 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - `nclasses`: the number of output classes """ -function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., +function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, + nclasses::Integer = 1000) + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_norm((1, 1), 64, 80)..., - conv_norm((3, 3), 80, 192)..., + basic_conv_bn((1, 1), 64, 80)..., + basic_conv_bn((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), inceptionv3_a(192, 32), inceptionv3_a(256, 64), diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl new file mode 100644 index 000000000..13d40da25 --- /dev/null +++ b/src/convnets/inceptions/inceptionv4.jl @@ -0,0 +1,147 @@ +function mixed_3a() + return Parallel(cat_channels, + MaxPool((3, 3); stride = 2), + Chain(basic_conv_bn((3, 3), 64, 96; stride = 2)...)) +end + +function mixed_4a() + return Parallel(cat_channels, + Chain(basic_conv_bn((1, 1), 160, 64)..., + basic_conv_bn((3, 3), 64, 96)...), + Chain(basic_conv_bn((1, 1), 160, 64)..., + basic_conv_bn((7, 1), 64, 64; pad = (3, 0))..., + basic_conv_bn((1, 7), 64, 64; pad = (0, 3))..., + basic_conv_bn((3, 3), 64, 96)...)) +end + +function mixed_5a() + return Parallel(cat_channels, + Chain(basic_conv_bn((3, 3), 192, 192; stride = 2)...), + MaxPool((3, 3); stride = 2)) +end + +function inceptionv4_a() + branch1 = Chain(basic_conv_bn((1, 1), 384, 96)...) + branch2 = Chain(basic_conv_bn((1, 1), 384, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)...) + branch3 = Chain(basic_conv_bn((1, 1), 384, 64)..., + basic_conv_bn((3, 3), 64, 96; pad = 1)..., + basic_conv_bn((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 384, 96)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_a() + branch1 = Chain(basic_conv_bn((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 384, 192)..., + basic_conv_bn((3, 3), 192, 224; pad = 1)..., + basic_conv_bn((3, 3), 224, 256; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_b() + branch1 = Chain(basic_conv_bn((1, 1), 1024, 384)...) + branch2 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((7, 1), 192, 224; pad = (0, 3))..., + basic_conv_bn((1, 7), 224, 256; pad = (3, 0))...) + branch3 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((1, 7), 192, 192; pad = (3, 0))..., + basic_conv_bn((7, 1), 192, 224; pad = (0, 3))..., + basic_conv_bn((1, 7), 224, 224; pad = (3, 0))..., + basic_conv_bn((7, 1), 224, 256; pad = (0, 3))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 1024, 128)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_b() + branch1 = Chain(basic_conv_bn((1, 1), 1024, 192)..., + basic_conv_bn((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(basic_conv_bn((1, 1), 1024, 256)..., + basic_conv_bn((7, 1), 256, 256; pad = (3, 0))..., + basic_conv_bn((1, 7), 256, 320; pad = (0, 3))..., + basic_conv_bn((3, 3), 320, 320; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_c() + branch1 = Chain(basic_conv_bn((1, 1), 1536, 256)...) + branch2 = Chain(basic_conv_bn((1, 1), 1536, 384)..., + Parallel(cat_channels, + Chain(basic_conv_bn((3, 1), 384, 256; pad = (1, 0))...), + Chain(basic_conv_bn((1, 3), 384, 256; pad = (0, 1))...))) + branch3 = Chain(basic_conv_bn((1, 1), 1536, 384)..., + basic_conv_bn((1, 3), 384, 448; pad = (0, 1))..., + basic_conv_bn((3, 1), 448, 512; pad = (1, 0))..., + Parallel(cat_channels, + Chain(basic_conv_bn((3, 1), 512, 256; pad = (1, 0))...), + Chain(basic_conv_bn((1, 3), 512, 256; pad = (0, 1))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), + basic_conv_bn((1, 1), 1536, 256)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +""" + inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) + +Create an Inceptionv4 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. + - `nclasses`: the number of output classes. +""" +function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) + backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., + basic_conv_bn((3, 3), 32, 32)..., + basic_conv_bn((3, 3), 32, 64; pad = 1)..., + mixed_3a(), mixed_4a(), mixed_5a(), + [inceptionv4_a() for _ in 1:4]..., + reduction_a(), # mixed_6a + [inceptionv4_b() for _ in 1:7]..., + reduction_b(), # mixed_7a + [inceptionv4_c() for _ in 1:3]...) + return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) +end + +""" + Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) + +Creates an Inceptionv4 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `nclasses`: the number of output classes. + +!!! warning + + `Inceptionv4` does not currently support pretrained weights. +""" +struct Inceptionv4 + layers::Any +end +@functor Inceptionv4 + +function Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) + layers = inceptionv4(; inchannels, nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv4") + end + return Inceptionv4(layers) +end + +(m::Inceptionv4)(x) = m.layers(x) + +backbone(m::Inceptionv4) = m.layers[1] +classifier(m::Inceptionv4) = m.layers[2] diff --git a/src/convnets/inception/xception.jl b/src/convnets/inceptions/xception.jl similarity index 83% rename from src/convnets/inception/xception.jl rename to src/convnets/inceptions/xception.jl index 4964f3ca1..171bddd19 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -8,7 +8,7 @@ Create an Xception block. # Arguments - - `inchannels`: The number of channels in the input. + - `inchannels`: number of input channels - `outchannels`: number of output channels. - `nrepeats`: number of repeats of depthwise separable convolution layers. - `stride`: stride by which to downsample the input. @@ -19,8 +19,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int stride::Integer = 1, start_with_relu::Bool = true, grow_at_start::Bool = true) if outchannels != inchannels || stride != 1 - skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride) else skip = [identity] end @@ -35,8 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, - use_norm = (false, false))) + dwsep_conv_bn((3, 3), inc, outc; pad = 1, use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -45,27 +43,27 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end """ - xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) + xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - `nclasses`: the number of output classes. """ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) - backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_norm((3, 3), 32, 64; bias = false)..., + backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 64)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), xception_block(128, 256, 2; stride = 2), xception_block(256, 728, 2; stride = 2), [xception_block(728, 728, 3) for _ in 1:8]..., xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...) + dwsep_conv_bn((3, 3), 1024, 1536; pad = 1)..., + dwsep_conv_bn((3, 3), 1536, 2048; pad = 1)...) return Chain(backbone, create_classifier(2048, nclasses; dropout_rate)) end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl deleted file mode 100644 index 84162e985..000000000 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ /dev/null @@ -1,98 +0,0 @@ -""" - mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) - -Create a MobileNetv2 model. -([reference](https://arxiv.org/abs/1801.04381)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer - + `c`: The number of output feature maps - + `n`: The number of times a block is repeated - + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: The number of output classes -""" -function mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; - max_width::Integer = 1280, inchannels::Integer = 3, - nclasses::Integer = 1000) - divisor = width_mult == 0.1 ? 4 : 8 - # building first layer - inplanes = _round_channels(32 * width_mult, divisor) - layers = [] - append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) - # building inverted residual blocks - for (t, c, n, s, a) in configs - outplanes = _round_channels(c * width_mult, divisor) - for i in 1:n - push!(layers, - invertedresidual((3, 3), inplanes, outplanes, a; expansion = t, - stride = i == 1 ? s : 1)) - inplanes = outplanes - end - end - # building last layers - outplanes = width_mult > 1 ? _round_channels(max_width * width_mult, divisor) : - max_width - append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) -end - -# Layer configurations for MobileNetv2 -const MOBILENETV2_CONFIGS = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), -] - -""" - MobileNetv2(width_mult = 1.0; inchannels::Integer = 3, pretrain::Bool = false, - nclasses::Integer = 1000) - -Create a MobileNetv2 model with the specified configuration. -([reference](https://arxiv.org/abs/1801.04381)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `inchannels`: The number of input channels. - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv2`](#). -""" -struct MobileNetv2 - layers::Any -end -@functor MobileNetv2 - -function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("MobileNetv2")) - end - return MobileNetv2(layers) -end - -(m::MobileNetv2)(x) = m.layers(x) - -backbone(m::MobileNetv2) = m.layers[1] -classifier(m::MobileNetv2) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl deleted file mode 100644 index 7d06ab14d..000000000 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ /dev/null @@ -1,133 +0,0 @@ -""" - mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) - -Create a MobileNetv3 model. -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - - `configs`: a "list of tuples" configuration for each layer that details: - - + `k::Integer` - The size of the convolutional kernel - + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - + `t::Integer` - The number of output feature maps for a given block - + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - + `s::Integer` - The stride of the convolutional kernel - + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: the number of output classes -""" -function mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) - # building first layer - inplanes = _round_channels(16 * width_mult, 8) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, - bias = false)) - explanes = 0 - # building inverted residual blocks - for (k, t, c, r, a, s) in configs - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - invertedresidual((k, k), inplanes, explanes, outplanes, a; - stride = s, reduction = r)) - inplanes = outplanes - end - # building last layers - output_channel = max_width - output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : - output_channel - append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)) - classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses)) - return Chain(Chain(layers...), classifier) -end - -# Layer configurations for small and large models for MobileNetv3 -const MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) - -""" - MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create a MobileNetv3 model with the specified configuration. -([reference](https://arxiv.org/abs/1905.02244)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -# Arguments - - - `config`: :small or :large for the size of the model (see paper). - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. - - `nclasses`: the number of output classes - -See also [`Metalhead.mobilenetv3`](#). -""" -struct MobileNetv3 - layers::Any -end -@functor MobileNetv3 - -function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, [:small, :large]) - max_width = (config == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[config]; max_width, inchannels, - nclasses) - if pretrain - loadpretrain!(layers, string("MobileNetv3", config)) - end - return MobileNetv3(layers) -end - -(m::MobileNetv3)(x) = m.layers(x) - -backbone(m::MobileNetv3) = m.layers[1] -classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl new file mode 100644 index 000000000..2f6db2acf --- /dev/null +++ b/src/convnets/mobilenets/mnasnet.jl @@ -0,0 +1,125 @@ +# momentum used for BatchNorm as per Tensorflow implementation +const _MNASNET_BN_MOMENTUM = 0.0003f0 + +""" + mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width = 1280, dropout_rate = 0.2, inchannels::Integer = 3, + nclasses::Integer = 1000) + +Create an MNASNet model with the specified configuration. +([reference](https://arxiv.org/abs/1807.11626)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper) + - `max_width`: The maximum number of feature maps in any layer of the network + - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes +""" +function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1280, inplanes::Integer = 32, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + # norm layer for MNASNet is different from other models + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = _MNASNET_BN_MOMENTUM, + kwargs...) + # building first layer + inplanes = _round_channels(inplanes * width_mult) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, relu; stride = 2, pad = 1, + norm_layer)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, + norm_layer) + append!(layers, cnn_stages(get_layers, block_repeats, +)) + # building last layers + outplanes = _round_channels(block_configs[end][3] * width_mult) + append!(layers, + conv_norm((1, 1), outplanes, max_width, relu; norm_layer)) + return Chain(Chain(layers...), create_classifier(max_width, nclasses; dropout_rate)) +end + +# Layer configurations for MNasNet +# f: block function - we use `dwsep_conv_bn` for the first block and `mbconv` for the rest +# k: kernel size +# c: output channels +# e: expansion factor - only used for `mbconv` +# s: stride +# n: number of repeats +# r: reduction factor - only used for `mbconv` +# a: activation function +# Data is organised as (f, k, c, (e,) s, n, (r,) a) +const MNASNET_CONFIGS = Dict(:B1 => (32, + [ + (dwsep_conv_bn, 3, 16, 1, 1, relu), + (mbconv, 3, 24, 3, 2, 3, nothing, relu), + (mbconv, 5, 40, 3, 2, 3, nothing, relu), + (mbconv, 5, 80, 6, 2, 3, nothing, relu), + (mbconv, 3, 96, 6, 1, 2, nothing, relu), + (mbconv, 5, 192, 6, 2, 4, nothing, relu), + (mbconv, 3, 320, 6, 1, 1, nothing, relu), + ]), + :A1 => (32, + [ + (dwsep_conv_bn, 3, 16, 1, 1, relu), + (mbconv, 3, 24, 6, 2, 2, nothing, relu), + (mbconv, 5, 40, 3, 2, 3, 4, relu), + (mbconv, 3, 80, 6, 2, 4, nothing, relu), + (mbconv, 3, 112, 6, 1, 2, 4, relu), + (mbconv, 5, 160, 6, 2, 3, 4, relu), + (mbconv, 3, 320, 6, 1, 1, nothing, relu), + ]), + :small => (8, + [ + (dwsep_conv_bn, 3, 8, 1, 1, relu), + (mbconv, 3, 16, 3, 2, 1, nothing, relu), + (mbconv, 3, 16, 6, 2, 2, nothing, relu), + (mbconv, 5, 32, 6, 2, 4, 4, relu), + (mbconv, 3, 32, 6, 1, 3, 4, relu), + (mbconv, 5, 88, 6, 2, 3, 4, relu), + (mbconv, 3, 144, 6, 1, 1, nothing, relu)])) + +""" + MNASNet(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false, + nclasses::Integer = 1000) + +Creates a MNASNet model with the specified configuration. +([reference](https://arxiv.org/abs/1807.11626)) + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes + +!!! warning + + `MNASNet` does not currently support pretrained weights. + +See also [`mnasnet`](#). +""" +struct MNASNet + layers::Any +end +@functor MNASNet + +function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, keys(MNASNET_CONFIGS)) + inplanes, block_configs = MNASNET_CONFIGS[config] + layers = mnasnet(block_configs; width_mult, inplanes, inchannels, nclasses) + if pretrain + loadpretrain!(layers, "mnasnet$(width_mult)") + end + return MNASNet(layers) +end + +(m::MNASNet)(x) = m.layers(x) + +backbone(m::MNASNet) = m.layers[1] +classifier(m::MNASNet) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl similarity index 51% rename from src/convnets/mobilenet/mobilenetv1.jl rename to src/convnets/mobilenets/mobilenetv1.jl index b6d9fe8ee..24240d0c0 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -1,5 +1,6 @@ """ - mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, + mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). @@ -7,7 +8,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) + (with 1 being the default in the paper) - `configs`: A "list of tuples" configuration for each layer that details: @@ -16,60 +17,63 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + `s`: The stride of the convolutional kernel + `r`: The number of time this configuration block is repeated - `activate`: The activation function to use throughout the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. - `inchannels`: The number of input channels. The default value is 3. - `nclasses`: The number of output classes """ -function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, - inchannels::Integer = 3, nclasses::Integer = 1000) +function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1, + activation = relu, dropout_rate = nothing, + inplanes::Integer = 32, inchannels::Integer = 3, + nclasses::Integer = 1000) layers = [] - for (dw, outch, stride, nrepeats) in config - outch = floor(Int, outch * width_mult) - for _ in 1:nrepeats - layer = dw ? - depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, - bias = false) - append!(layers, layer) - inchannels = outch - end - end - return Chain(Chain(layers...), create_classifier(inchannels, nclasses)) + # stem of the model + inplanes = _round_channels(inplanes * width_mult) + append!(layers, + conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(config, inplanes; width_mult) + append!(layers, cnn_stages(get_layers, block_repeats)) + outplanes = _round_channels(config[end][3] * width_mult) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 +# f: block function - we use `dwsep_conv_bn` for all blocks +# k: kernel size +# c: output channels +# s: stride +# n: number of repeats +# a: activation function const MOBILENETV1_CONFIGS = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), + # f, k, c, s, n, a + (dwsep_conv_bn, 3, 64, 1, 1, relu6), + (dwsep_conv_bn, 3, 128, 2, 2, relu6), + (dwsep_conv_bn, 3, 256, 2, 2, relu6), + (dwsep_conv_bn, 3, 512, 2, 6, relu6), + (dwsep_conv_bn, 3, 1024, 2, 2, relu6), ] """ - MobileNetv1(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false, - nclasses::Integer = 1000) + MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model with the baseline configuration ([reference](https://arxiv.org/abs/1704.04861v1)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inchannels`: The number of input channels. - `nclasses`: The number of output classes -See also [`Metalhead.mobilenetv1`](#). +!!! warning + + `MobileNetv1` does not currently support pretrained weights. + +See also [`mobilenetv1`](#). """ struct MobileNetv1 layers::Any @@ -78,7 +82,7 @@ end function MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses) + layers = mobilenetv1(MOBILENETV1_CONFIGS; width_mult, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv1")) end diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl new file mode 100644 index 000000000..f3e26862c --- /dev/null +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -0,0 +1,105 @@ +""" + mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create a MobileNetv2 model. +([reference](https://arxiv.org/abs/1801.04381)). + +# Arguments + + - `configs`: A "list of tuples" configuration for each layer that details: + + + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer + + `c`: The number of output feature maps + + `n`: The number of times a block is repeated + + `s`: The stride of the convolutional kernel + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper) + - `max_width`: The maximum number of feature maps in any layer of the network + - `divisor`: The divisor used to round the number of feature maps in each block + - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes +""" +function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1280, divisor::Integer = 8, + inplanes::Integer = 32, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + # building first layer + inplanes = _round_channels(inplanes * width_mult, divisor) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, + divisor) + append!(layers, cnn_stages(get_layers, block_repeats, +)) + # building last layers + outplanes = _round_channels(block_configs[end][3] * width_mult, divisor) + headplanes = _round_channels(max_width * max(1, width_mult), divisor) + append!(layers, conv_norm((1, 1), outplanes, headplanes, relu6)) + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) +end + +# Layer configurations for MobileNetv2 +# f: block function - we use `mbconv` for all blocks +# k: kernel size +# c: output channels +# e: expansion factor +# s: stride +# n: number of repeats +# r: reduction factor +# a: activation function +const MOBILENETV2_CONFIGS = [ + # f, k, c, e, s, n, r, a + (mbconv, 3, 16, 1, 1, 1, nothing, swish), + (mbconv, 3, 24, 6, 2, 2, nothing, swish), + (mbconv, 3, 32, 6, 2, 3, nothing, swish), + (mbconv, 3, 64, 6, 2, 4, nothing, swish), + (mbconv, 3, 96, 6, 1, 3, nothing, swish), + (mbconv, 3, 160, 6, 2, 3, nothing, swish), + (mbconv, 3, 320, 6, 1, 1, nothing, swish), +] + +""" + MobileNetv2(width_mult = 1.0; inchannels::Integer = 3, pretrain::Bool = false, + nclasses::Integer = 1000) + +Create a MobileNetv2 model with the specified configuration. +([reference](https://arxiv.org/abs/1801.04381)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes + +!!! warning + + `MobileNetv2` does not currently support pretrained weights. + +See also [`mobilenetv2`](#). +""" +struct MobileNetv2 + layers::Any +end +@functor MobileNetv2 + +function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + layers = mobilenetv2(MOBILENETV2_CONFIGS; width_mult, inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv2")) + end + return MobileNetv2(layers) +end + +(m::MobileNetv2)(x) = m.layers(x) + +backbone(m::MobileNetv2) = m.layers[1] +classifier(m::MobileNetv2) = m.layers[2] diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl new file mode 100644 index 000000000..2614c7c2f --- /dev/null +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -0,0 +1,125 @@ +""" + mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create a MobileNetv3 model. +([reference](https://arxiv.org/abs/1905.02244)). + +# Arguments + + - `configs`: a "list of tuples" configuration for each layer that details: + + + `k::Integer` - The size of the convolutional kernel + + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer + + `t::Integer` - The number of output feature maps for a given block + + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + + `s::Integer` - The stride of the convolutional kernel + + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) + - `max_width`: The maximum number of feature maps in any layer of the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. + - `inchannels`: The number of input channels. + - `nclasses`: the number of output classes +""" +function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + # building first layer + inplanes = _round_channels(16 * width_mult) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult, + se_from_explanes = true, + se_round_fn = _round_channels) + append!(layers, cnn_stages(get_layers, block_repeats, +)) + # building last layers + explanes = _round_channels(configs[end][3] * width_mult) + midplanes = _round_channels(explanes * configs[end][4]) + append!(layers, conv_norm((1, 1), explanes, midplanes, hardswish)) + return Chain(Chain(layers...), + create_classifier(midplanes, max_width, nclasses, + (hardswish, identity); dropout_rate)) +end + +# Layer configurations for small and large models for MobileNetv3 +# f: mbconv block function - we use `mbconv` for all blocks +# k: kernel size +# c: output channels +# e: expansion factor +# s: stride +# n: number of repeats +# r: squeeze and excite reduction factor +# a: activation function +# Data is organised as (f, k, c, e, s, n, r, a) +const MOBILENETV3_CONFIGS = Dict(:small => [ + (mbconv, 3, 16, 1, 2, 1, 4, relu), + (mbconv, 3, 24, 4.5, 2, 1, nothing, relu), + (mbconv, 3, 24, 3.67, 1, 1, nothing, relu), + (mbconv, 5, 40, 4, 2, 1, 4, hardswish), + (mbconv, 5, 40, 6, 1, 2, 4, hardswish), + (mbconv, 5, 48, 3, 1, 2, 4, hardswish), + (mbconv, 5, 96, 6, 1, 3, 4, hardswish), + ], + :large => [ + (mbconv, 3, 16, 1, 1, 1, nothing, relu), + (mbconv, 3, 24, 4, 2, 1, nothing, relu), + (mbconv, 3, 24, 3, 1, 1, nothing, relu), + (mbconv, 5, 40, 3, 2, 1, 4, relu), + (mbconv, 5, 40, 3, 1, 2, 4, relu), + (mbconv, 3, 80, 6, 2, 1, nothing, hardswish), + (mbconv, 3, 80, 2.5, 1, 1, nothing, hardswish), + (mbconv, 3, 80, 2.3, 1, 2, nothing, hardswish), + (mbconv, 3, 112, 6, 1, 2, 4, hardswish), + (mbconv, 5, 160, 6, 1, 3, 4, hardswish), + ]) + +""" + MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create a MobileNetv3 model with the specified configuration. +([reference](https://arxiv.org/abs/1905.02244)). +Set `pretrain = true` to load the model with pre-trained weights for ImageNet. + +# Arguments + + - `config`: :small or :large for the size of the model (see paper). + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels + - `nclasses`: the number of output classes + +!!! warning + + `MobileNetv3` does not currently support pretrained weights. + +See also [`mobilenetv3`](#). +""" +struct MobileNetv3 + layers::Any +end +@functor MobileNetv3 + +function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, [:small, :large]) + max_width = config == :large ? 1280 : 1024 + layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels, + nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv3", config)) + end + return MobileNetv3(layers) +end + +(m::MobileNetv3)(x) = m.layers(x) + +backbone(m::MobileNetv3) = m.layers[1] +classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 699edcbe8..dc143b57e 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -12,7 +12,7 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385 - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block - - `reduction_factor`: the factor by which the input feature maps are reduced before + - `reduction_factor`: the factor by which the input feature maps are reduced before the first convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. @@ -27,12 +27,11 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor - outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, - stride, pad = 1, bias = false) - conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, - pad = 1, bias = false) - layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), + stride, pad = 1) + conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm, + pad = 1) + layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes), drop_path] return Chain(filter!(!=(identity), layers)...) end @@ -72,12 +71,11 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, - bias = false) + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, + revnorm) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, - stride, pad = 1, groups = cardinality, bias = false) - conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm, - bias = false) + stride, pad = 1, groups = cardinality) + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) @@ -85,18 +83,18 @@ end # Downsample layer using convolutions. function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, revnorm = false) + norm_layer = BatchNorm, revnorm::Bool = false) return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - pad = SamePad(), stride, bias = false)...) + pad = SamePad(), stride)...) end # Downsample layer using max pooling function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, revnorm = false) - pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) + norm_layer = BatchNorm, revnorm::Bool = false) + pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, - bias = false)...) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, + revnorm)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -117,13 +115,13 @@ end # Shortcut configurations for the ResNet models const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity), - :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv), - :D => (downsample_pool, downsample_identity)) + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) # Stride for each block in the ResNet model function resnet_stride(stage_idx::Integer, block_idx::Integer) - return (stage_idx == 1 || block_idx != 1) ? 1 : 2 + return stage_idx == 1 || block_idx != 1 ? 1 : 2 end # returns `DropBlock`s for each stage of the ResNet as in timm. @@ -156,8 +154,8 @@ on how to use this function. convolution. This is an experimental variant from the `timm` library in Python that shows peformance improvements over the `:deep` stem in some cases. - - `inchannels`: The number of channels in the input. - - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + + - `inchannels`: number of input channels + - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. @@ -165,8 +163,7 @@ on how to use this function. function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, replace_pool::Bool = false, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false) - @assert stem_type in [:default, :deep, :deep_tiered] - "Stem type must be one of [:default, :deep, :deep_tiered]" + _checkconfig(stem_type, [:default, :deep, :deep_tiered]) # Main stem deep_stem = stem_type == :deep || stem_type == :deep_tiered inplanes = deep_stem ? stem_width * 2 : 64 @@ -178,9 +175,9 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; - norm_layer, revnorm, stride = 2, pad = 1, bias = false)..., + norm_layer, revnorm, stride = 2, pad = 1)..., conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; - norm_layer, pad = 1, bias = false)..., + norm_layer, pad = 1)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) @@ -189,8 +186,7 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - revnorm, - stride = 2, pad = 1, bias = false)...) : + revnorm, stride = 2, pad = 1)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool) end @@ -200,93 +196,11 @@ function resnet_planes(block_repeats::AbstractVector{<:Integer}) for (stage_idx, stages) in enumerate(block_repeats)) end -function basicblock_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, reduction_factor::Integer = 1, - expansion::Integer = 1, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # This is also needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? - downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - block = basicblock(inplanes, planes; stride, reduction_factor, activation, - norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - -function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, cardinality::Integer = 1, - base_width::Integer = 64, reduction_factor::Integer = 1, - expansion::Integer = 4, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule - # This is also needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? - downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - block = bottleneck(inplanes, planes; stride, cardinality, base_width, - reduction_factor, activation, norm_layer, revnorm, - attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - -function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) - # Construct each stage - stages = [] - for (stage_idx, num_blocks) in enumerate(block_repeats) - # Construct the blocks for each stage - blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...) - for block_idx in 1:num_blocks] - push!(stages, Chain(blocks...)) - end - return Chain(stages...) -end - function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection) - backbone = Chain(stem, stage_blocks) + stage_blocks = cnn_stages(get_layers, block_repeats, connection) + backbone = Chain(stem, stage_blocks...) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] return Chain(backbone, classifier_fn(nfeaturemaps)) @@ -294,13 +208,14 @@ end function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); - cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, + cardinality::Integer = 1, base_width::Integer = 64, + inplanes::Integer = 64, reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...) + use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, + dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) # Block builder @@ -312,26 +227,26 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, drop_block_rate, drop_path_rate, stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_block_rate, drop_path_rate, stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`" - @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`" - @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`" + @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. + Set `drop_block_rate` to nothing." + @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. + Set `drop_path_rate` to nothing." + @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. + Set `reduction_factor` to 1." get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, - kwargs...) + downsample_tuple = downsample_opt, kwargs...) else # TODO: write better message when we have link to dev docs for resnet throw(ArgumentError("Unknown block type $block_type")) @@ -351,7 +266,8 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) - +# block configurations for larger ResNet-like models that do not use +# depths 18 and 34 const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 8e054da82..e308e1125 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -8,6 +8,7 @@ Creates a bottleneck block as described in the Res2Net paper. ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block @@ -29,21 +30,19 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, outplanes = planes * 4 pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride, - pad = 1, groups = cardinality, bias = false)...) + pad = 1, groups = cardinality)...) for _ in 1:max(1, scale - 1)] reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) - tuplify = if is_first - x -> tuple(x...) - else - x -> tuple(x[1], tuple(x[2:end]...)) - end - layers = [conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm, bias = false)..., - chunk$(; size = width, dims = 3), tuplify, reslayer, - conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm, bias = false)..., - attn_fn(outplanes)] + tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) + layers = [ + conv_norm((1, 1), inplanes => width * scale, activation; + norm_layer, revnorm)..., + chunk$(; size = width, dims = 3), tuplify, reslayer, + conv_norm((1, 1), width * scale => outplanes, activation; + norm_layer, revnorm)..., + attn_fn(outplanes), + ] return Chain(filter(!=(identity), layers)...) end @@ -86,6 +85,7 @@ Creates a Res2Net model with the specified depth, scale, and base width. ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `scale`: the number of feature groups in the block. See the @@ -125,6 +125,7 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina ([reference](https://arxiv.org/abs/1904.01169)) # Arguments + - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `scale`: the number of feature groups in the block. See the diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index cdccddd4b..b65b71072 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -2,7 +2,7 @@ ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a ResNet model with the specified depth. -((reference)[https://arxiv.org/abs/1512.03385]) +([reference](https://arxiv.org/abs/1512.03385)) # Arguments @@ -39,14 +39,14 @@ classifier(m::ResNet) = m.layers[2] Creates a Wide ResNet model with the specified depth. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same. -((reference)[https://arxiv.org/abs/1605.07146]) +([reference](https://arxiv.org/abs/1605.07146)) # Arguments - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the Wide ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: The number of input channels. - - `nclasses`: the number of output classes + - `nclasses`: The number of output classes Advanced users who want more configuration options will be better served by using [`resnet`](#). """ diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 8c43d2f62..2a8fbd561 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -3,16 +3,18 @@ base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a ResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/abs/1611.05431]) +([reference](https://arxiv.org/abs/1611.05431)) # Arguments - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. Supported configurations are: - - depth 50, cardinality of 32 and base width of 4. - - depth 101, cardinality of 32 and base width of 8. - - depth 101, cardinality of 64 and base width of 4. + + + depth 50, cardinality of 32 and base width of 4. + + depth 101, cardinality of 32 and base width of 8. + + depth 101, cardinality of 64 and base width of 4. - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. - `inchannels`: the number of input channels. @@ -30,9 +32,11 @@ end function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) - layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) + layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, + base_width) if pretrain - loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d")) + loadpretrain!(layers, + string("resnext", depth, "_", cardinality, "x", base_width, "d")) end return ResNeXt(layers) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index da074e57d..487665518 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -2,7 +2,7 @@ SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a SEResNet model with the specified depth. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) +([reference](https://arxiv.org/pdf/1709.01507.pdf)) # Arguments @@ -22,8 +22,6 @@ struct SEResNet end @functor SEResNet -(m::SEResNet)(x) = m.layers(x) - function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(RESNET_CONFIGS)) @@ -35,6 +33,8 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = return SEResNet(layers) end +(m::SEResNet)(x) = m.layers(x) + backbone(m::SEResNet) = m.layers[1] classifier(m::SEResNet) = m.layers[2] @@ -43,7 +43,7 @@ classifier(m::SEResNet) = m.layers[2] base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) Creates a SEResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) +([reference](https://arxiv.org/pdf/1709.01507.pdf)) # Arguments @@ -65,12 +65,12 @@ struct SEResNeXt end @functor SEResNeXt -(m::SEResNeXt)(x) = m.layers(x) - function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, - base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) + base_width::Integer = 4, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) - layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, + layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, + base_width, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) @@ -78,5 +78,7 @@ function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer return SEResNeXt(layers) end +(m::SEResNeXt)(x) = m.layers(x) + backbone(m::SEResNeXt) = m.layers[1] classifier(m::SEResNeXt) = m.layers[2] diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index de232d9a3..163c13b68 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -17,7 +17,7 @@ function vgg_block(ifilters::Integer, ofilters::Integer, depth::Integer, batchno layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_norm(k, ifilters, ofilters; pad = p, bias = false)) + append!(layers, conv_norm(k, ifilters, ofilters; pad = p)) else push!(layers, Conv(k, ifilters => ofilters, relu; pad = p)) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 04be476ff..9bdf1f913 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -11,13 +11,15 @@ using MLUtils using PartialFunctions using Random +import Flux.testmode! + include("../utilities.jl") include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, depthwise_sep_conv_norm, invertedresidual +export conv_norm, basic_conv_bn, dwsep_conv_bn include("drop.jl") export DropBlock, DropPath @@ -25,8 +27,14 @@ export DropBlock, DropPath include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens +include("mbconv.jl") +export mbconv, fused_mbconv + include("mlp.jl") -export mlp_block, gated_mlp_block, create_fc, create_classifier +export mlp_block, gated_mlp_block + +include("classifier.jl") +export create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl new file mode 100644 index 000000000..bebdc4099 --- /dev/null +++ b/src/layers/classifier.jl @@ -0,0 +1,93 @@ +""" + create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + +Creates a classifier head to be used for models. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `activation`: activation function to use + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv && pool_layer !== identity + if use_conv + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # Dropout is applied after the pooling layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # Fully-connected layer + if use_conv + push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) + else + push!(classifier, Dense(inplanes => nclasses, activation)) + end + return Chain(classifier...) +end + +""" + create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + +Creates a classifier head to be used for models with an extra hidden layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `hidden_planes`: number of hidden feature maps + - `nclasses`: number of output classes + - `activations`: activation functions to use for the hidden and output layers. This is a + tuple of two elements, the first being the activation function for the hidden layer and the + second for the output layer. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. This + is a tuple of two booleans, the first for the hidden layer and the second for the output + layer. + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2, Any} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + fc_layers = [uc ? Conv$(1, 1) : Dense for uc in use_conv] + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv[1] && pool_layer !== identity + if use_conv[1] + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv[1]` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # first fully-connected layer + if !isnothing(hidden_planes) + push!(classifier, fc_layers[1](inplanes => hidden_planes, activations[1])) + end + # Dropout is applied after the first dense layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # second fully-connected layer + push!(classifier, fc_layers[2](hidden_planes => nclasses, activations[2])) + return Chain(classifier...) +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index de214bcbc..e49611280 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,10 +1,11 @@ """ - conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, - use_norm::Bool = true, stride::Integer = 1, pad::Integer = 0, - dilation::Integer = 1, groups::Integer = 1, [bias, weight, init]) + conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + stride::Integer = 1, pad::Integer = 0, dilation::Integer = 1, + groups::Integer = 1, [bias, weight, init]) - conv_norm(kernel_size, inplanes => outplanes, activation = identity; + conv_norm(kernel_size::Dims{2}, inplanes => outplanes, activation = identity; kwargs...) Create a convolution + batch normalization pair with activation. @@ -25,25 +26,32 @@ Create a convolution + batch normalization pair with activation. - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `bias`: bias for the convolution kernel. This is set to `false` by default if + `use_norm = true`. + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; - norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, - use_norm::Bool = true, kwargs...) +function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, + eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, + bias = !use_norm, kwargs...) + # no normalization layer if !use_norm - if (preact || revnorm) + if preact || revnorm throw(ArgumentError("`preact` only supported with `use_norm = true`")) else + # early return if no norm layer is required return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] end end + # channels for norm layer and activation functions for both conv and norm if revnorm activations = (conv = activation, bn = identity) - bnplanes = inplanes + normplanes = inplanes else activations = (conv = identity, bn = activation) - bnplanes = outplanes + normplanes = outplanes end + # handle pre-activation if preact if revnorm throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time")) @@ -51,101 +59,21 @@ function conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activatio activations = (conv = activation, bn = identity) end end - layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), - norm_layer(bnplanes, activations.bn)] + # layers + layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), + norm_layer(normplanes, activations.bn; ϵ = eps)] return revnorm ? reverse(layers) : layers end - -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; - kwargs...) +function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, + activation = identity; kwargs...) inplanes, outplanes = ch return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end -""" - depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), - pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) - -Create a depthwise separable convolution chain as used in MobileNetv1. -This is sequence of layers: - - - a `kernel_size` depthwise convolution from `inplanes => inplanes` - - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise - `activation` is applied to the convolution output) - - a `kernel_size` convolution from `inplanes => outplanes` - - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise - `activation` is applied to the convolution output) - -See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - -# Arguments - - - `kernel_size`: size of the convolution kernel (tuple) - - `inplanes`: number of input feature maps - - `outplanes`: number of output feature maps - - `activation`: the activation function for the final layer - - `revnorm`: set to `true` to place the batch norm before the convolution - - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and - second convolution - - `stride`: stride of the first convolution kernel - - `pad`: padding of the first convolution kernel - - `dilation`: dilation of the first convolution kernel - - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) -""" -function depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, stride::Integer = 1, - use_norm::NTuple{2, Bool} = (true, true), kwargs...) - return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; - norm_layer, revnorm, use_norm = use_norm[1], stride, - groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, - use_norm = use_norm[2])) -end - -""" - invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; - stride, reduction = nothing) - -Create a basic inverted residual block for MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `kernel_size`: kernel size of the convolutional layers - - `inplanes`: number of input feature maps - - `hidden_planes`: The number of feature maps in the hidden layer - - `outplanes`: The number of output feature maps - - `activation`: The activation function for the first two convolution layer - - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)). -""" -function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer, activation = relu; stride::Integer, - reduction::Union{Nothing, Integer} = nothing) - @assert stride in [1, 2] "`stride` has to be 1 or 2" - pad = @. (kernel_size - 1) ÷ 2 - conv1 = (inplanes == hidden_planes) ? (identity,) : - conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false) - selayer = isnothing(reduction) ? identity : - squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, - norm_layer = BatchNorm) - invres = Chain(conv1..., - conv_norm(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad = pad, groups = hidden_planes)..., - selayer, - conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) - return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres -end - -function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion::Real, - reduction::Union{Nothing, Integer} = nothing) - hidden_planes = floor(Int, inplanes * expansion) - return invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation; - stride, reduction) +# conv + bn layer combination as used by the inception model family matching +# the default values used in TensorFlow +function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu; + kwargs...) + return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, + eps = 1.0f-3, kwargs...) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 31c06c07a..8a82d5b16 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -20,7 +20,8 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. - `x`: input array - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -56,11 +57,25 @@ dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) The `DropBlock` layer. While training, it zeroes out continguous regions of size `block_size` in the input. During inference, it simply returns the input `x`. -((reference)[https://arxiv.org/abs/1810.12890]) +It can be used in two ways: either with all blocks having the same survival probability +or with a linear scaling rule across the blocks. This is performed only at training time. +At test time, the `DropBlock` layer is equivalent to `identity`. + +!!! warning + + In the case of the linear scaling rule, the calculations of survival probabilities for each + block may lead to a survival probability > 1 for a given block. This will lead to + `DropBlock` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. + +([reference](https://arxiv.org/abs/1810.12890)) # Arguments - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -78,10 +93,8 @@ end trainable(a::DropBlock) = (;) function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_scale) - @assert 0 ≤ drop_block_prob ≤ 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert 0 ≤ gamma_scale ≤ 1 - return "gamma_scale must be between 0 and 1, got $gamma_scale" + @assert 0≤drop_block_prob≤1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0≤gamma_scale≤1 "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) @@ -90,11 +103,8 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_s function (m::DropBlock)(x) _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) - if Flux._isactive(m) - return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) - else - return x - end + return Flux._isactive(m) ? + dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) : x end function Flux.testmode!(m::DropBlock, mode = true) @@ -103,7 +113,7 @@ end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, rng = rng_from_array()) - if drop_block_prob == 0.0 + if isnothing(drop_block_prob) return identity end return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) @@ -116,11 +126,12 @@ function Base.show(io::IO, d::DropBlock) return print(io, ")") end +# TODO look into "row" mode for stochastic depth """ DropPath(p; [rng = rng_from_array(x)]) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and -`identity` otherwise. +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 ≤ p ≤ 1` and +`identity` if p is `nothing`. ([reference](https://arxiv.org/abs/1603.09382)) This layer can be used to drop certain blocks in a residual structure and allow them to @@ -133,10 +144,10 @@ equivalent to `identity`. In the case of the linear scaling rule, the calculations of survival probabilities for each block may lead to a survival probability > 1 for a given block. This will lead to - `DropPath` returning `identity`, which may not be desirable. This usually happens with - a low number of blocks and a high base survival probability, so it is recommended to - use a fixed base survival probability across blocks. If this is not possible, then - a lower base survival probability is recommended. + `DropPath` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. # Arguments @@ -145,4 +156,6 @@ equivalent to `identity`. for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU. """ -DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity +function DropPath(p; rng = rng_from_array()) + return isnothing(p) ? identity : Dropout(p; dims = 4, rng) +end diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index cb9b8378c..abdab4b44 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -11,7 +11,7 @@ patches. # Arguments - `imsize`: the size of the input image - - `inchannels`: the number of channels in the input. + - `inchannels`: number of input channels - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding - `norm_layer`: the normalization layer - by default the identity function but otherwise takes a diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl new file mode 100644 index 000000000..e37a98406 --- /dev/null +++ b/src/layers/mbconv.jl @@ -0,0 +1,152 @@ +""" + dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, [bias, weight, init]) + +Create a depthwise separable convolution chain as used in MobileNetv1. +This is sequence of layers: + + - a `kernel_size` depthwise convolution from `inplanes => inplanes` + - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise + `activation` is applied to the convolution output) + - a `kernel_size` convolution from `inplanes => outplanes` + - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise + `activation` is applied to the convolution output) + +See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). + +# Arguments + + - `kernel_size`: size of the convolution kernel (tuple) + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + - `activation`: the activation function for the final layer + - `revnorm`: set to `true` to place the batch norm before the convolution + - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and + second convolution + - `bias`: a tuple of two booleans to specify whether to use bias for the first and second + convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and + `use_norm[1] == true`. + - `stride`: stride of the first convolution kernel + - `pad`: padding of the first convolution kernel + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) +""" +function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false, + stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true), + bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...) + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, + revnorm, use_norm = use_norm[1], stride, bias = bias[1], + groups = inplanes, kwargs...), + conv_norm((1, 1), inplanes, outplanes, activation; eps, + revnorm, use_norm = use_norm[2], bias = bias[2])) +end + +# TODO add support for stochastic depth to mbconv and fused_mbconv +""" + mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Real} = nothing, + se_round_fn = x -> round(Int, x), norm_layer = BatchNorm, kwargs...) + +Create a basic inverted residual block for MobileNet and Efficient variants. +This is a sequence of layers: + + - a 1x1 convolution from `inplanes => explanes` followed by a (batch) normalisation layer + + - `activation` if `inplanes != explanes` + - a `kernel_size` depthwise separable convolution from `explanes => explanes` + - a (batch) normalisation layer + - a squeeze-and-excitation block (if `reduction != nothing`) from + `explanes => se_round_fn(explanes / reduction)` and back to `explanes` + - a 1x1 convolution from `explanes => outplanes` + - a (batch) normalisation layer + `activation` + +First introduced in the MobileNetv2 paper. +(See Fig. 3 in [reference](https://arxiv.org/abs/1801.04381v4).) + +# Arguments + + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps + - `explanes`: The number of expanded feature maps. This is the number of feature maps + after the first 1x1 convolution. + - `outplanes`: The number of output feature maps + - `activation`: The activation function for the first two convolution layer + - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 + - `reduction`: The reduction factor for the number of hidden feature maps + in a squeeze and excite layer (see [`squeeze_excite`](#)) + - `se_round_fn`: The function to round the number of reduced feature maps + in the squeeze and excite layer + - `norm_layer`: The normalization layer to use +""" +function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + reduction::Union{Nothing, Real} = nothing, + se_round_fn = x -> round(Int, x), norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" + layers = [] + # expand + if inplanes != explanes + append!(layers, + conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) + end + # depthwise + append!(layers, + conv_norm(kernel_size, explanes, explanes, activation; norm_layer, + stride, pad = SamePad(), groups = explanes)) + # squeeze-excite layer + if !isnothing(reduction) + push!(layers, + squeeze_excite(explanes; round_fn = se_round_fn, reduction, + activation, gate_activation = hardσ)) + end + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) + return Chain(layers...) +end + +""" + fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) + +Create a fused inverted residual block. + +This is a sequence of layers: + + - a `kernel_size` depthwise separable convolution from `explanes => explanes` + - a (batch) normalisation layer + - a 1x1 convolution from `explanes => outplanes` followed by a (batch) normalisation + layer + `activation` if `inplanes != explanes` + +Originally introduced by Google in [EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML](https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html). +Later used in the EfficientNetv2 paper. + +# Arguments + + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps + - `explanes`: The number of expanded feature maps + - `outplanes`: The number of output feature maps + - `activation`: The activation function for the first two convolution layer + - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 + - `norm_layer`: The normalization layer to use +""" +function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, + explanes::Integer, outplanes::Integer, activation = relu; + stride::Integer, norm_layer = BatchNorm) + @assert stride in [1, 2] "`stride` has to be 1 or 2 for `fused_mbconv`" + layers = [] + # fused expand + explanes = explanes == inplanes ? outplanes : explanes + append!(layers, + conv_norm(kernel_size, inplanes, explanes, activation; norm_layer, stride, + pad = SamePad())) + if explanes != inplanes + # project + append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) + end + return Chain(layers...) +end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 0e496097b..e6336de9c 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,3 +1,4 @@ +# TODO @theabhirath figure out consistent behaviour for dropout rates - 0.0 vs `nothing` """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; dropout_rate = 0., activation = gelu) @@ -45,40 +46,3 @@ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) - -""" - create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = 0.0, use_conv::Bool = false) - -Creates a classifier head to be used for models. - -# Arguments - - - `inplanes`: number of input feature maps - - `nclasses`: number of output classes - - `activation`: activation function to use - - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with - any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. - - `dropout_rate`: dropout rate used in the classifier head. - - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. -""" -function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = nothing) - # Decide whether to flatten the input or not - flatten_in_pool = !use_conv && pool_layer !== identity - if use_conv - @assert pool_layer === identity - "`pool_layer` must be identity if `use_conv` is true" - end - classifier = [] - flatten_in_pool ? push!(classifier, pool_layer, MLUtils.flatten) : - push!(classifier, pool_layer) - # Dropout is applied after the pooling layer - isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) - # Fully-connected layer - use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) : - push!(classifier, Dense(inplanes => nclasses, activation)) - return Chain(classifier...) -end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 0756225ba..044d61dbf 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -1,33 +1,44 @@ """ - squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, - activation = relu, gate_activation = sigmoid, norm_layer = identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + squeeze_excite(inplanes::Integer, squeeze_planes::Integer; + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) -Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. + squeeze_excite(inplanes::Integer; reduction::Real = 16, + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) + +Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets. # Arguments - `inplanes`: The number of input feature maps - - `reduction`: The reduction factor for the number of hidden feature maps - - `rd_divisor`: The divisor for the number of hidden feature maps. + - `squeeze_planes`: The number of feature maps in the intermediate layers. Alternatively, + specify the keyword arguments `reduction` and `rd_divisior`, which determine the number + of feature maps in the intermediate layers from the number of input feature maps as: + `squeeze_planes = _round_channels(inplanes ÷ reduction)`. (See [`_round_channels`](#) for details.) - `activation`: The activation function for the first convolution layer - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer """ -function squeeze_excite(inplanes::Integer; reduction::Integer = 16, - rd_divisor::Integer = 8, activation = relu, - gate_activation = sigmoid, norm_layer = planes -> identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) +# TODO look into a `get_norm_act` layer that will return a closure over the norm layer +# with the activation function passed in when the norm layer is not `identity` +function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; + norm_layer = planes -> identity, activation = relu, + gate_activation = sigmoid) layers = [AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes => rd_planes), - norm_layer(rd_planes), + Conv((1, 1), inplanes => squeeze_planes), + norm_layer(squeeze_planes), activation, - Conv((1, 1), rd_planes => inplanes), + Conv((1, 1), squeeze_planes => inplanes), norm_layer(inplanes), gate_activation] return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end +function squeeze_excite(inplanes::Integer; reduction::Real = 16, + round_fn = _round_channels, kwargs...) + return squeeze_excite(inplanes, round_fn(inplanes / reduction); kwargs...) +end """ effective_squeeze_excite(inplanes, gate_activation = sigmoid) @@ -42,6 +53,5 @@ Effective squeeze-and-excitation layer. """ function effective_squeeze_excite(inplanes::Integer; gate_activation = sigmoid) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes, inplanes), - gate_activation), .*) + Conv((1, 1), inplanes => inplanes, gate_activation)), .*) end diff --git a/src/utilities.jl b/src/utilities.jl index f5737831c..13b8ec385 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -3,10 +3,10 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants -function _round_channels(channels, divisor, min_value = divisor) +function _round_channels(channels::Number, divisor::Integer = 8, min_value::Integer = 0) new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) # Make sure that round down does not go down by more than 10% - return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels + return new_channels < 0.9 * channels ? new_channels + divisor : new_channels end """ @@ -66,13 +66,17 @@ function _maybe_big_show(io, model) end """ - linear_scheduler(drop_path_rate = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_rate = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_rate::Nothing; depth::Integer) -Returns the dropout rates for a given depth using the linear scaling rule. +Returns the dropout rates for a given depth using the linear scaling rule. If the +`drop_rate` is `nothing`, it returns a `Vector` of length `depth` with all values +equal to `nothing`. """ function linear_scheduler(drop_rate = 0.0; depth::Integer, start_value = 0.0) return LinRange(start_value, drop_rate, depth) end +linear_scheduler(drop_rate::Nothing; depth::Integer) = fill(drop_rate, depth) # Utility function for depth and configuration checks in models function _checkconfig(config, configs) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 099d00639..75bfb5b07 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -61,7 +61,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3, Dropout(emb_dropout_rate), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout_rate), - (pool == :class) ? x -> x[:, 1, :] : seconddimmean), + pool == :class ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end diff --git a/test/convnets.jl b/test/convnets.jl index 6d7dab496..0c796e24c 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -161,7 +161,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] x = rand(Float32, r, r, 3, 1) @@ -177,6 +177,20 @@ end end end +@testset "EfficientNetv2" begin + @testset for config in [:small, :medium, :large] # :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end +end + @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(x_224)) == (1000, 1) @@ -263,11 +277,11 @@ end end end -@testset "MobileNet" verbose = true begin +@testset "MobileNets (width = $width_mult)" for width_mult in [0.5, 0.75, 1, 1.3] @testset "MobileNetv1" begin - m = MobileNetv1() + m = MobileNetv1(width_mult) @test size(m(x_224)) == (1000, 1) - if MobileNetv1 in PRETRAINED_MODELS + if (MobileNetv1, width_mult) in PRETRAINED_MODELS @test acctest(MobileNetv1(pretrain = true)) else @test_throws ArgumentError MobileNetv1(pretrain = true) @@ -276,9 +290,9 @@ end end _gc() @testset "MobileNetv2" begin - m = MobileNetv2() + m = MobileNetv2(width_mult) @test size(m(x_224)) == (1000, 1) - if MobileNetv2 in PRETRAINED_MODELS + if (MobileNetv2, width_mult) in PRETRAINED_MODELS @test acctest(MobileNetv2(pretrain = true)) else @test_throws ArgumentError MobileNetv2(pretrain = true) @@ -288,9 +302,9 @@ end _gc() @testset "MobileNetv3" verbose = true begin @testset for config in [:small, :large] - m = MobileNetv3(config) + m = MobileNetv3(config; width_mult) @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config) in PRETRAINED_MODELS + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS @test acctest(MobileNetv3(config; pretrain = true)) else @test_throws ArgumentError MobileNetv3(config; pretrain = true) @@ -299,6 +313,19 @@ end _gc() end end + @testset "MNASNet" verbose = true begin + @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) + else + @test_throws ArgumentError MNASNet(config; pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end end @testset "ConvNeXt" verbose = true begin