From cd0edef3957db44f37a12ea156ec49182bfcfec9 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 16 Jun 2022 18:16:13 +0530 Subject: [PATCH 01/64] Add `DropBlock` --- Project.toml | 2 + src/convnets/densenet.jl | 2 +- src/layers/Layers.jl | 8 ++-- src/layers/drop.jl | 61 ++++++++++++++++++++++++++++ src/layers/{mlp.jl => mlp-linear.jl} | 15 +++++++ src/layers/others.jl | 26 ------------ 6 files changed, 84 insertions(+), 30 deletions(-) create mode 100644 src/layers/drop.jl rename src/layers/{mlp.jl => mlp-linear.jl} (83%) delete mode 100644 src/layers/others.jl diff --git a/Project.toml b/Project.toml index be550c1ef..bd618e534 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,14 @@ version = "0.7.3" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 0b318dbf3..588d2ad22 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -100,7 +100,7 @@ Create a DenseNet model - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ -function densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000) +function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, nclasses = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; reduction = reduction, nclasses = nclasses) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 1034136f3..8b2b73059 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,8 +1,10 @@ module Layers using Flux -using Flux: outputsize, Zygote +using NNlib +using NNlibCUDA using Functors +using ChainRulesCore using Statistics using MLUtils @@ -10,10 +12,10 @@ include("../utilities.jl") include("attention.jl") include("embeddings.jl") -include("mlp.jl") +include("mlp-linear.jl") include("normalise.jl") include("conv.jl") -include("others.jl") +include("drop.jl") export MHAttention, PatchEmbedding, ViPosEmbedding, ClassTokens, diff --git a/src/layers/drop.jl b/src/layers/drop.jl new file mode 100644 index 000000000..93c120651 --- /dev/null +++ b/src/layers/drop.jl @@ -0,0 +1,61 @@ +""" + DropBlock(drop_prob = 0.1, block_size = 7) + +Implements DropBlock, a regularization method for convolutional networks. 
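Editor's note: for intuition, DropBlock zeroes out contiguous `block_size × block_size` regions of a feature map instead of independent activations, then rescales the survivors the way dropout does. The sketch below is a self-contained illustration of that idea only; `dropblock_sketch` and its exact seed-probability computation are this note's own and not the implementation added in this patch.

```julia
using NNlib  # provides maxpool

# Toy DropBlock forward pass for a (H, W, C, N) array: sample sparse Bernoulli
# "seed" points, grow each seed into a block via max-pooling, then mask and rescale.
function dropblock_sketch(x::AbstractArray{<:AbstractFloat, 4}, drop_prob, block_size)
    H, W, _, _ = size(x)
    # seed probability chosen so that roughly `drop_prob` of activations end up dropped
    γ = drop_prob * H * W / block_size^2 / ((H - block_size + 1) * (W - block_size + 1))
    seeds = Float32.(rand(Float32, size(x)) .< γ)
    blocks = maxpool(seeds, (block_size, block_size); stride = 1, pad = block_size ÷ 2)
    keep = 1 .- blocks[1:H, 1:W, :, :]
    return x .* keep .* (length(keep) / max(sum(keep), 1))  # rescale like dropout
end

dropblock_sketch(rand(Float32, 28, 28, 8, 2), 0.1, 7)  # 28×28×8×2 output
```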
+([reference](https://arxiv.org/pdf/1810.12890.pdf)) +""" +struct DropBlock{F} + drop_prob::F + block_size::Integer +end +@functor DropBlock + +(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size) + +DropBlock(drop_prob = 0.1, block_size = 7) = DropBlock(drop_prob, block_size) + +function _dropblock_checks(x, drop_prob, T) + if !(T <: AbstractArray) + throw(ArgumentError("x must be an `AbstractArray`")) + end + if ndims(x) != 4 + throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) + end + @assert drop_prob < 0 || drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" +end +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, T) + +function dropblock(x::T, drop_prob, block_size::Integer) where {T} + _dropblock_checks(x, drop_prob, T) + if drop_prob == 0 + return x + end + return _dropblock(x, drop_prob, block_size) +end + +function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size) where {T} + gamma = drop_prob / (block_size ^ 2) + mask = rand_like(x, Float32, (size(x, 1), size(x, 2), size(x, 3))) + mask .<= gamma + block_mask = maxpool(reshape(mask, (size(mask)[1:3]..., 1)), (block_size, block_size); + pad = block_size ÷ 2, stride = (1, 1)) + if block_size % 2 == 0 + block_mask = block_mask[1:(end - 1), 1:(end - 1), :, :] + end + block_mask = 1 .- dropdims(block_mask; dims = 4) + out = (x .* reshape(block_mask, (size(block_mask)[1:3]..., 1))) * length(block_mask) / + sum(block_mask) + return out +end + +""" + DropPath(p) + +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. +([reference](https://arxiv.org/abs/1603.09382)) + +# Arguments + + - `p`: rate of Stochastic Depth. +""" +DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity diff --git a/src/layers/mlp.jl b/src/layers/mlp-linear.jl similarity index 83% rename from src/layers/mlp.jl rename to src/layers/mlp-linear.jl index 25ead874b..e282e2632 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp-linear.jl @@ -1,3 +1,18 @@ +""" + LayerScale(λ, planes::Integer) + +Creates a `Flux.Scale` layer that performs "`LayerScale`" +([reference](https://arxiv.org/abs/2103.17239)). + +# Arguments + + - `planes`: Size of channel dimension in the input. + - `λ`: initialisation value for the learnable diagonal matrix. +""" +function LayerScale(planes::Integer, λ) + return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity +end + """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; dropout = 0., activation = gelu) diff --git a/src/layers/others.jl b/src/layers/others.jl deleted file mode 100644 index 770bccebd..000000000 --- a/src/layers/others.jl +++ /dev/null @@ -1,26 +0,0 @@ -""" - LayerScale(λ, planes::Integer) - -Creates a `Flux.Scale` layer that performs "`LayerScale`" -([reference](https://arxiv.org/abs/2103.17239)). - -# Arguments - - - `planes`: Size of channel dimension in the input. - - `λ`: initialisation value for the learnable diagonal matrix. -""" -function LayerScale(planes::Integer, λ) - return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity -end - -""" - DropPath(p) - -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. -([reference](https://arxiv.org/abs/1603.09382)) - -# Arguments - - - `p`: rate of Stochastic Depth. -""" -DropPath(p) = p ≥ 0 ? 
Dropout(p; dims = 4) : identity From 271b430567e396367b359b92d8a0cd51c1b5bed7 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 21 Jun 2022 15:21:12 +0530 Subject: [PATCH 02/64] Initial commit for new ResNet API --- docs/make.jl | 4 +- docs/serve.jl | 2 +- src/Metalhead.jl | 4 +- src/convnets/densenet.jl | 3 +- src/convnets/resnet.jl | 408 ++++++++++++++++----------------------- src/layers/Layers.jl | 3 +- src/layers/drop.jl | 45 +++-- src/layers/normalise.jl | 4 +- 8 files changed, 202 insertions(+), 271 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index db03f1d76..f5d29f7e9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Pkg -Pkg.develop(path = "..") +Pkg.develop(; path = "..") using Publish using Artifacts, LazyArtifacts @@ -13,5 +13,5 @@ p = Publish.Project(Metalhead) function build_and_deploy(label) rm(label; recursive = true, force = true) - deploy(Metalhead; root = "/Metalhead.jl", label = label) + return deploy(Metalhead; root = "/Metalhead.jl", label = label) end diff --git a/docs/serve.jl b/docs/serve.jl index 763e77e93..bf4a51179 100644 --- a/docs/serve.jl +++ b/docs/serve.jl @@ -1,6 +1,6 @@ using Pkg -Pkg.develop(path = "..") +Pkg.develop(; path = "..") using Revise using Publish diff --git a/src/Metalhead.jl b/src/Metalhead.jl index f391c0c66..9a60ad351 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -40,7 +40,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, +# ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -49,7 +49,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, +for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, # :ResNet, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 588d2ad22..374909bb1 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -100,7 +100,8 @@ Create a DenseNet model - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ -function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, nclasses = 1000) +function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, + nclasses = 1000) where {N} return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; reduction = reduction, nclasses = nclasses) end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl index 53d1fd6e3..768697131 100644 --- a/src/convnets/resnet.jl +++ b/src/convnets/resnet.jl @@ -1,259 +1,185 @@ -""" - basicblock(inplanes, outplanes, downsample = false) - -Create a basic residual block -([reference](https://arxiv.org/abs/1512.03385v1)). 
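Editor's note: this commit replaces the hand-written ResNet blocks and assembly code with a more configurable builder, so the diff below is large. For orientation, every block variant shares the same residual pattern: a convolutional main path and a (possibly projecting) skip path are summed and passed through an activation. A self-contained plain-Flux sketch of that pattern follows; the helper name and exact layer choices are this note's own, not Metalhead code.

```julia
using Flux

# Generic residual block: main path plus skip path, then a nonlinearity.
function residual_block_sketch(inplanes, outplanes; stride = 1)
    main = Chain(Conv((3, 3), inplanes => outplanes; stride, pad = 1, bias = false),
                 BatchNorm(outplanes, relu),
                 Conv((3, 3), outplanes => outplanes; pad = 1, bias = false),
                 BatchNorm(outplanes))
    # project the input whenever the two paths would otherwise disagree in shape
    skip = (stride == 1 && inplanes == outplanes) ? identity :
           Chain(Conv((1, 1), inplanes => outplanes; stride, bias = false),
                 BatchNorm(outplanes))
    return Chain(Parallel(+, main, skip), relu)
end

residual_block_sketch(64, 128; stride = 2)(rand(Float32, 56, 56, 64, 1)) |> size  # (28, 28, 128, 1)
```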
- -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input -""" -function basicblock(inplanes, outplanes, downsample = false) - stride = downsample ? 2 : 1 - return Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, - bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, - bias = false)...) +function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, + reduce_first = 1, dilation = 1, first_dilation = nothing, + act_layer = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = 1 + @assert cardinality==1 "BasicBlock only supports cardinality of 1" + @assert base_width==64 "BasicBlock does not support changing base width" + first_planes = planes ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? first_dilation : dilation + conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, + dilation = first_dilation, bias = false), + norm_layer(first_planes)) + drop_block = drop_block === identity ? identity : drop_block + conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; stride, pad = dilation, + dilation = dilation, bias = false), + norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_path)), + act_layer) end -""" - bottleneck(inplanes, outplanes, downsample = false; stride = [1, (downsample ? 2 : 1), 1]) - -Create a bottleneck residual block -([reference](https://arxiv.org/abs/1512.03385v1)). The bottleneck is composed of -3 convolutional layers each with the given `stride`. -By default, `stride` implements ["ResNet v1.5"](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch) -which uses `stride == [1, 2, 1]` when `downsample == true`. -This version is standard across various ML frameworks. -The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead. - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input - - `stride`: a list of the stride of the 3 convolutional layers -""" -function bottleneck(inplanes, outplanes, downsample = false; - stride = [1, (downsample ? 2 : 1), 1]) - return Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], - bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, - bias = false)..., - conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], - bias = false)...) +function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, + reduce_first = 1, dilation = 1, first_dilation = nothing, + act_layer = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = 4 + width = floor(Int, planes * (base_width / 64)) * cardinality + first_planes = width ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? 
first_dilation : dilation + conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), + norm_layer(first_planes)) + conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, + dilation = first_dilation, groups = cardinality, bias = false), + norm_layer(width)) + drop_block = drop_block === identity ? identity : drop_block() + conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_block, + act_layer, conv_bn3, drop_path)), + act_layer) end -""" - bottleneck_v1(inplanes, outplanes, downsample = false) - -Create a bottleneck residual block -([reference](https://arxiv.org/abs/1512.03385v1)). The bottleneck is composed of -3 convolutional layers with all a stride of 1 except the first convolutional -layer which has a stride of 2. - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input -""" -function bottleneck_v1(inplanes, outplanes, downsample = false) - return bottleneck(inplanes, outplanes, downsample; - stride = [(downsample ? 2 : 1), 1, 1]) +function drop_blocks(drop_prob = 0.0) + return [identity, identity, + drop_prob == 0.0 ? DropBlock(drop_prob, 5, 0.25) : identity, + drop_prob == 0.0 ? DropBlock(drop_prob, 3, 1.00) : identity] end -""" - resnet(block, residuals::NTuple{2, Any}, connection = addrelu; - channel_config, block_config, nclasses = 1000) - -Create a ResNet model -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - `residuals`: a 2-tuple of functions with input `(inplanes, outplanes, downsample=false)`, - each of which will return a function that will be used as a new "skip" path to match a residual block. - [`Metalhead.skip_identity`](#) and [`Metalhead.skip_projection`](#) can be used here. 
- - `connection`: the binary function applied to the output of residual and skip paths in a block - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection = addrelu; - channel_config, block_config, nclasses = 1000) - inplanes = 64 - baseplanes = 64 - layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false)) - push!(layers, MaxPool((3, 3); stride = (2, 2), pad = (1, 1))) - for (i, nrepeats) in enumerate(block_config) - # output planes within a block - outplanes = baseplanes .* channel_config - # push first skip connection on using first residual - # downsample the residual path if this is the first repetition of a block - push!(layers, - Parallel(connection, block(inplanes, outplanes, i != 1), - residuals[i][1](inplanes, outplanes[end], i != 1))) - # push remaining skip connections on using second residual - inplanes = outplanes[end] - for _ in 2:nrepeats - push!(layers, - Parallel(connection, block(inplanes, outplanes, false), - residuals[i][2](inplanes, outplanes[end], false))) - inplanes = outplanes[end] - end - # next set of output plane base is doubled - baseplanes *= 2 - end - # next set of output plane base is doubled - baseplanes *= 2 - return Chain(Chain(layers), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(inplanes, nclasses))) +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) end -""" - resnet(block, shortcut_config::Symbol, connection = addrelu; - channel_config, block_config, nclasses = 1000) - -Create a ResNet model -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - - `shortcut_config`: the type of shortcut style (either `:A`, `:B`, or `:C`) - - + `:A`: uses a [`Metalhead.skip_identity`](#) for all residual blocks - + `:B`: uses a [`Metalhead.skip_projection`](#) for the first residual block - and [`Metalhead.skip_identity`](@) for the remaining residual blocks - + `:C`: uses a [`Metalhead.skip_projection`](#) for all residual blocks - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnet(block, shortcut_config::AbstractVector{<:Symbol}, args...; kwargs...) 
- shortcut_dict = Dict(:A => (skip_identity, skip_identity), - :B => (skip_projection, skip_identity), - :C => (skip_projection, skip_projection)) - if any(sc -> !haskey(shortcut_dict, sc), shortcut_config) - error("Unrecognized shortcut_config ($shortcut_config) passed to `resnet` (use only :A, :B, or :C).") +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 + pool = avg_pool_fn((2, 2); stride = avg_stride, pad) end - shortcut = [shortcut_dict[sc] for sc in shortcut_config] - return resnet(block, shortcut, args...; kwargs...) -end - -function resnet(block, shortcut_config::Symbol, args...; block_config, kwargs...) - return resnet(block, fill(shortcut_config, length(block_config)), args...; - block_config = block_config, kwargs...) -end - -function resnet(block, residuals::NTuple{2}, args...; kwargs...) - return resnet(block, [residuals], args...; kwargs...) -end - -const resnet_config = Dict(18 => (([1, 1], [2, 2, 2, 2], [:A, :B, :B, :B]), basicblock), - 34 => (([1, 1], [3, 4, 6, 3], [:A, :B, :B, :B]), basicblock), - 50 => (([1, 1, 4], [3, 4, 6, 3], [:B, :B, :B, :B]), bottleneck), - 101 => (([1, 1, 4], [3, 4, 23, 3], [:B, :B, :B, :B]), bottleneck), - 152 => (([1, 1, 4], [3, 8, 36, 3], [:B, :B, :B, :B]), bottleneck)) - -""" - ResNet(channel_config, block_config, shortcut_config; - block, connection = addrelu, nclasses = 1000) -Create a `ResNet` model -([reference](https://arxiv.org/abs/1512.03385v1)). -See also [`resnet`](#). - -# Arguments - - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `shortcut_config`: the type of shortcut style (either `:A`, `:B`, or `:C`). - `shortcut_config` can also be a vector of symbols if different shortcut styles are applied to - different residual blocks. - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `nclasses`: the number of output classes -""" -struct ResNet - layers::Any + return Chain(pool, + Conv((1, 1), in_channels => out_channels; stride = 1, pad = 0, + bias = false), + norm_layer(out_channels)) end -function ResNet(channel_config, block_config, shortcut_config; - block, connection = addrelu, nclasses = 1000) - layers = resnet(block, - shortcut_config, - connection; - channel_config = channel_config, - block_config = block_config, - nclasses = nclasses) - return ResNet(layers) +function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, + reduce_first = 1, output_stride = 32, + down_kernel_size = 1, avg_down = false, drop_block_rate = 0.0, + drop_path_rate = 0.0, kwargs...) + kwarg_dict = Dict(kwargs...) + stages = [] + net_block_idx = 1 + net_stride = 4 + dilation = prev_dilation = 1 + for (stage_idx, (planes, num_blocks, db)) in enumerate(zip(channels, block_repeats, + drop_blocks(drop_block_rate))) + stride = stage_idx == 1 ? 1 : 2 + if net_stride >= output_stride + dilation *= stride + stride = 1 + else + net_stride *= stride + end + downsample = identity + if stride != 1 || inplanes != planes * expansion + downsample = avg_down ? 
+ downsample_avg(down_kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation = prev_dilation, + norm_layer = kwarg_dict[:norm_layer]) : + downsample_conv(down_kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation = prev_dilation, + norm_layer = kwarg_dict[:norm_layer]) + end + block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, + :drop_block => db, kwargs...) + blocks = [] + for block_idx in 1:num_blocks + downsample = block_idx == 1 ? downsample : identity + stride = block_idx == 1 ? stride : 1 + # stochastic depth linear decay rule + block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + push!(blocks, + block_fn(inplanes, planes; stride, downsample, + first_dilation = prev_dilation, + drop_path = DropPath(block_dpr), block_kwargs...)) + prev_dilation = dilation + inplanes = planes * expansion + net_block_idx += 1 + end + push!(stages, Chain(blocks...)) + end + return Chain(stages...) end -@functor ResNet - -(m::ResNet)(x) = m.layers(x) - -backbone(m::ResNet) = m.layers[1] -classifier(m::ResNet) = m.layers[2] - -""" - ResNet(depth = 50; pretrain = false, nclasses = 1000) - -Create a ResNet model with a specified depth -([reference](https://arxiv.org/abs/1512.03385v1)) -following [these modification](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch) -referred as ResNet v1.5. - -See also [`Metalhead.resnet`](#). - -# Arguments - - - `depth`: depth of the ResNet model. Options include (18, 34, 50, 101, 152). - - `nclasses`: the number of output classes - -For `ResNet(18)` and `ResNet(34)`, the parameter-free shortcut style (type `:A`) -is used in the first block and the three other blocks use type `:B` connection -(following the implementation in PyTorch). The published version of -`ResNet(18)` and `ResNet(34)` used type `:A` shortcuts for all four blocks. The -example below shows how to create a 18 or 34-layer `ResNet` using only type `:A` -shortcuts: - -```julia -using Metalhead - -resnet18 = ResNet([1, 1], [2, 2, 2, 2], :A; block = Metalhead.basicblock) +function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, + expansion = 1, + cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, + replace_stem_pool = false, reduce_first = 1, + down_kernel_size = (1, 1), avg_down = false, act_layer = relu, + norm_layer = BatchNorm, + drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, + block_kwargs...) + @assert output_stride in (8, 16, 32) + @assert stem_type in [:default, :deep, :deep_tiered] + # Stem + inplanes = stem_type == :deep ? 
stem_width * 2 : 64 + if stem_type == :deep + stem_channels = (stem_width, stem_width) + if stem_type == :deep_tiered + stem_channels = (3 * (stem_width ÷ 4), stem_width) + end + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + bias = false), + norm_layer(stem_channels[1]), + act_layer(), + Conv((3, 3), stem_channels[1] => stem_channels[1]; stride = 1, + pad = 1, bias = false), + norm_layer(stem_channels[2]), + act_layer(), + Conv((3, 3), stem_channels[2] => inplanes; stride = 1, pad = 1, + bias = false)) + else + conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) + end + bn1 = norm_layer(inplanes) + act1 = act_layer + # Stem pooling + if replace_stem_pool + stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, + bias = false), + norm_layer(inplanes), + act_layer) + else + stempool = MaxPool((3, 3); stride = 2, pad = 1) + end + stem = Chain(conv1, bn1, act1, stempool) -resnet34 = ResNet([1, 1], [3, 4, 6, 3], :A; block = Metalhead.basicblock) -``` + # Feature Blocks + channels = [64, 128, 256, 512] + stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, + output_stride, reduce_first, avg_down, + down_kernel_size, act_layer, norm_layer, + drop_block_rate, drop_path_rate, block_kwargs...) -The bottleneck of the orginal ResNet model has a stride of 2 on the first -convolutional layer when downsampling (instead of the second convolutional layers -as in ResNet v1.5). The architecture of the orignal ResNet model can be obtained -as shown below: + # Head (Pooling and Classifier) + num_features = 512 * expansion + classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, + Dense(num_features, num_classes)) -```julia -resnet50_v1 = ResNet([1, 1, 4], [3, 4, 6, 3], :B; block = Metalhead.bottleneck_v1) -``` -""" -function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000) - @assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))" - config, block = resnet_config[depth] - model = ResNet(config...; block = block, nclasses = nclasses) - pretrain && loadpretrain!(model, string("resnet", depth)) - return model + return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 8b2b73059..6c417c077 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -24,5 +24,6 @@ export MHAttention, ChannelLayerNorm, prenorm, skip_identity, skip_projection, conv_bn, depthwise_sep_conv_bn, - invertedresidual, squeeze_excite + invertedresidual, squeeze_excite, + DropBlock end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 93c120651..b3f9a8719 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -7,45 +7,48 @@ Implements DropBlock, a regularization method for convolutional networks. 
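Editor's note: the hunk below reworks the DropBlock mask so the seed probability is derived from an explicit `gamma_scale` and a block size clipped to the feature-map extent. As a standalone illustration of that computation with hypothetical numbers (a 28×28 map, 7×7 blocks, 10% target drop rate):

```julia
H, W, block_size, drop_prob, gamma_scale = 28, 28, 7, 0.1, 1.0
clipped = min(block_size, H, W)   # mirrors `clipped_block_size` in the code below
gamma = gamma_scale * drop_prob * (H * W) / clipped^2 /
        ((W - block_size + 1) * (H - block_size + 1))   # ≈ 0.0033 here
```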
struct DropBlock{F} drop_prob::F block_size::Integer + gamma_scale::F end @functor DropBlock -(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size) +(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size, m.gamma_scale) -DropBlock(drop_prob = 0.1, block_size = 7) = DropBlock(drop_prob, block_size) +function DropBlock(drop_prob = 0.1, block_size = 7, gamma_scale = 1.0) + return DropBlock(drop_prob, block_size, gamma_scale) +end -function _dropblock_checks(x, drop_prob, T) +function _dropblock_checks(x, drop_prob, gamma_scale, T) if !(T <: AbstractArray) throw(ArgumentError("x must be an `AbstractArray`")) end if ndims(x) != 4 throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) end - @assert drop_prob < 0 || drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" + @assert drop_prob < 0||drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" + @assert gamma_scale < 0||gamma_scale > 1 "gamma_scale must be between 0 and 1, got $gamma_scale" end -ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, T) +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, gamma_scale, T) -function dropblock(x::T, drop_prob, block_size::Integer) where {T} - _dropblock_checks(x, drop_prob, T) +function dropblock(x::T, drop_prob, block_size::Integer, gamma_scale) where {T} + _dropblock_checks(x, drop_prob, gamma_scale, T) if drop_prob == 0 return x end - return _dropblock(x, drop_prob, block_size) + return _dropblock(x, drop_prob, block_size, gamma_scale) end -function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size) where {T} - gamma = drop_prob / (block_size ^ 2) - mask = rand_like(x, Float32, (size(x, 1), size(x, 2), size(x, 3))) - mask .<= gamma - block_mask = maxpool(reshape(mask, (size(mask)[1:3]..., 1)), (block_size, block_size); - pad = block_size ÷ 2, stride = (1, 1)) - if block_size % 2 == 0 - block_mask = block_mask[1:(end - 1), 1:(end - 1), :, :] - end - block_mask = 1 .- dropdims(block_mask; dims = 4) - out = (x .* reshape(block_mask, (size(block_mask)[1:3]..., 1))) * length(block_mask) / - sum(block_mask) - return out +function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size, gamma_scale) where {T} + H, W, _, _ = size(x) + total_size = H * W + clipped_block_size = min(block_size, min(H, W)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size^2 / + ((W - block_size + 1) * (H - block_size + 1)) + block_mask = rand_like(x) .< gamma + block_mask = maxpool(convert(T, block_mask), (clipped_block_size, clipped_block_size); + stride = 1, padding = clipped_block_size ÷ 2) + block_mask = 1 .- block_mask + normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) + return x * block_mask * normalize_scale end """ diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 4f69dab03..2d5e6399a 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -19,9 +19,9 @@ end @functor ChannelLayerNorm -(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) - function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) diag = Flux.Scale(1, 1, sz, λ) return ChannelLayerNorm(diag, ϵ) end + +(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) From 866dbcc81937f40e672bd65d4c73c90025be042a Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 22 Jun 2022 07:33:48 +0530 Subject: [PATCH 03/64] Cleanup --- src/convnets/inception.jl | 59 +++++++-------- 
src/convnets/resnet.jl | 146 ++++++++++++++++++-------------------- src/convnets/vgg.jl | 36 ++++------ src/layers/attention.jl | 16 ++--- src/layers/drop.jl | 5 +- src/layers/mlp-linear.jl | 20 +++--- src/other/mlpmixer.jl | 29 ++++---- src/vit-based/vit.jl | 29 ++++---- 8 files changed, 169 insertions(+), 171 deletions(-) diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index c3fd39f5e..e4106e957 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -279,7 +279,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) + inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -287,10 +287,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `dropout`: rate of dropout in classifier head. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) +function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -313,12 +313,13 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) inceptionv4_c(), inceptionv4_c(), inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses)) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + Dense(1536, nclasses)) return Chain(body, head) end """ - Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) + Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) Creates an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -326,8 +327,8 @@ Creates an Inceptionv4 model. # Arguments - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning @@ -338,7 +339,7 @@ struct Inceptionv4 layers::Any end -function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) +function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionv4(; inchannels, dropout, nclasses) pretrain && loadpretrain!(layers, "Inceptionv4") return Inceptionv4(layers) @@ -419,18 +420,18 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) + inceptionresnetv2(; inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) # Arguments - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. 
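Editor's note: across this patch the classifier-head dropout keyword is renamed from `dropout` to `drop_rate`. Hypothetical calls with the new spelling, assuming the patch is applied (argument values are arbitrary):

```julia
using Metalhead

Metalhead.inceptionresnetv2(; drop_rate = 0.2)          # previously `dropout = 0.2`
InceptionResNetv2(; drop_rate = 0.2, nclasses = 100)    # user-facing constructor
```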
""" -function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) +function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -446,12 +447,13 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), conv_bn((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses)) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + Dense(1536, nclasses)) return Chain(body, head) end """ - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) + InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -459,8 +461,8 @@ Creates an InceptionResNetv2 model. # Arguments - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning @@ -471,9 +473,9 @@ struct InceptionResNetv2 layers::Any end -function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, +function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = inceptionresnetv2(; inchannels, dropout, nclasses) + layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "InceptionResNetv2") return InceptionResNetv2(layers) end @@ -533,18 +535,18 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end """ - xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) + xception(; inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) +function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)..., conv_bn((3, 3), 32, 64; bias = false)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), @@ -554,7 +556,8 @@ function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) xception_block(728, 1024, 2; stride = 2, grow_at_start = false), depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses)) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + Dense(2048, nclasses)) return Chain(body, head) end @@ -563,7 +566,7 @@ struct Xception end """ - Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) + Xception(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -571,16 +574,16 @@ Creates an Xception model. 
# Arguments - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning `Xception` does not currently support pretrained weights. """ -function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - layers = xception(; inchannels, dropout, nclasses) +function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "xception") return Xception(layers) end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl index 768697131..875421360 100644 --- a/src/convnets/resnet.jl +++ b/src/convnets/resnet.jl @@ -1,9 +1,42 @@ +function drop_blocks(drop_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_prob, 5, 0.25), + DropBlock(drop_prob, 3, 1.00), + ] +end + +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) +end + +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 + pool = avg_pool_fn((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), in_channels => out_channels; bias = false), + norm_layer(out_channels)) +end + function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, - reduce_first = 1, dilation = 1, first_dilation = nothing, - act_layer = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity) - expansion = 1 + expansion = expansion_factor(basicblock) @assert cardinality==1 "BasicBlock only supports cardinality of 1" @assert base_width==64 "BasicBlock does not support changing base width" first_planes = planes ÷ reduce_first @@ -17,16 +50,16 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina dilation = dilation, bias = false), norm_layer(outplanes)) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_path)), - act_layer) + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + activation) end +expansion_factor(::typeof(basicblock)) = 1 function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, - reduce_first = 1, dilation = 1, first_dilation = nothing, - act_layer = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity) - expansion = 4 + expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first outplanes = planes * expansion @@ -39,55 +72,25 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina drop_block = drop_block === identity ? identity : drop_block() conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_block, - act_layer, conv_bn3, drop_path)), - act_layer) -end - -function drop_blocks(drop_prob = 0.0) - return [identity, identity, - drop_prob == 0.0 ? DropBlock(drop_prob, 5, 0.25) : identity, - drop_prob == 0.0 ? DropBlock(drop_prob, 3, 1.00) : identity] + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_block, + activation, conv_bn3, drop_path)), + activation) end +expansion_factor(::typeof(bottleneck)) = 4 -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(out_channels)) -end - -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 - pool = avg_pool_fn((2, 2); stride = avg_stride, pad) - end - - return Chain(pool, - Conv((1, 1), in_channels => out_channels; stride = 1, pad = 0, - bias = false), - norm_layer(out_channels)) -end - -function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, - reduce_first = 1, output_stride = 32, - down_kernel_size = 1, avg_down = false, drop_block_rate = 0.0, - drop_path_rate = 0.0, kwargs...) +function make_blocks(block_fn, channels, block_repeats, inplanes; + reduce_first = 1, output_stride = 32, down_kernel_size = 1, + avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, + kwargs...) + expansion = expansion_factor(block_fn) kwarg_dict = Dict(kwargs...) stages = [] net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 - for (stage_idx, (planes, num_blocks, db)) in enumerate(zip(channels, block_repeats, - drop_blocks(drop_block_rate))) + for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, + block_repeats, + drop_blocks(drop_block_rate))) stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride dilation *= stride @@ -95,6 +98,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, else net_stride *= stride end + # first block needs to be handled differently for downsampling downsample = identity if stride != 1 || inplanes != planes * expansion downsample = avg_down ? @@ -106,7 +110,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, norm_layer = kwarg_dict[:norm_layer]) end block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => db, kwargs...) + :drop_block => drop_block, kwargs...) blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? downsample : identity @@ -127,15 +131,13 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, end function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, - expansion = 1, cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, - down_kernel_size = (1, 1), avg_down = false, act_layer = relu, - norm_layer = BatchNorm, + replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), + avg_down = false, activation = relu, norm_layer = BatchNorm, drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, block_kwargs...) - @assert output_stride in (8, 16, 32) - @assert stem_type in [:default, :deep, :deep_tiered] + @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" + @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Stem inplanes = stem_type == :deep ? 
stem_width * 2 : 64 if stem_type == :deep @@ -145,38 +147,32 @@ function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride end conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, bias = false), - norm_layer(stem_channels[1]), - act_layer(), - Conv((3, 3), stem_channels[1] => stem_channels[1]; stride = 1, - pad = 1, bias = false), - norm_layer(stem_channels[2]), - act_layer(), - Conv((3, 3), stem_channels[2] => inplanes; stride = 1, pad = 1, - bias = false)) + norm_layer(stem_channels[1], activation), + Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, + bias = false), + norm_layer(stem_channels[2], activation), + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) end - bn1 = norm_layer(inplanes) - act1 = act_layer + bn1 = norm_layer(inplanes, activation) # Stem pooling if replace_stem_pool stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, bias = false), - norm_layer(inplanes), - act_layer) + norm_layer(inplanes, activation)) else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - stem = Chain(conv1, bn1, act1, stempool) - + stem = Chain(conv1, bn1, stempool) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, output_stride, reduce_first, avg_down, - down_kernel_size, act_layer, norm_layer, + down_kernel_size, activation, norm_layer, drop_block_rate, drop_path_rate, block_kwargs...) - # Head (Pooling and Classifier) + expansion = expansion_factor(block) num_features = 512 * expansion classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, Dense(num_features, num_classes)) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 56975a124..15560de7c 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -52,7 +52,7 @@ function vgg_convolutional_layers(config, batchnorm, inchannels) end """ - vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -63,19 +63,19 @@ Create VGG classifier (fully connected) layers the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `dropout`: the dropout level between each fully connected layer + - `drop_rate`: the dropout level between each fully connected layer """ -function vgg_classifier_layers(imsize, nclasses, fcsize, dropout) +function vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) return Chain(MLUtils.flatten, Dense(Int(prod(imsize)), fcsize, relu), - Dropout(dropout), + Dropout(drop_rate), Dense(fcsize, fcsize, relu), - Dropout(dropout), + Dropout(drop_rate), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). 
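Editor's note: the same `dropout` → `drop_rate` rename applies to the VGG builder. A hypothetical direct call with the new keyword; the config vector is the standard VGG-11 layout written out by hand here as `(output_channels, n_convs)` pairs, purely for illustration:

```julia
using Metalhead

m = Metalhead.vgg((224, 224); config = [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
                  inchannels = 3, batchnorm = false, nclasses = 10,
                  fcsize = 4096, drop_rate = 0.5)
```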
@@ -90,12 +90,12 @@ Create a VGG model - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `drop_rate`: dropout level between fully connected layers """ -function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) +function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + class = vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) return Chain(Chain(conv), class) end @@ -114,7 +114,7 @@ struct VGG end """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -126,17 +126,11 @@ Construct a VGG model with the specified input image size. Typically, the image - `nclasses`::Integer : number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `drop_rate`: dropout level between fully connected layers """ -function VGG(imsize::Dims{2}; - config, inchannels, batchnorm = false, nclasses, fcsize, dropout) - layers = vgg(imsize; config = config, - inchannels = inchannels, - batchnorm = batchnorm, - nclasses = nclasses, - fcsize = fcsize, - dropout = dropout) - +function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, + drop_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, drop_rate) return VGG(layers) end @@ -165,7 +159,7 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses batchnorm = batchnorm, nclasses = nclasses, fcsize = 4096, - dropout = 0.5) + drop_rate = 0.5) if pretrain && !batchnorm loadpretrain!(model, string("vgg", depth)) elseif pretrain diff --git a/src/layers/attention.jl b/src/layers/attention.jl index a1244a033..b6e7b7678 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -7,18 +7,18 @@ Multi-head self-attention layer. - `nheads`: Number of heads - `qkv_layer`: layer to be used for getting the query, key and value - - `attn_drop`: dropout rate after the self-attention layer + - `attn_drop_rate`: dropout rate after the self-attention layer - `projection`: projection layer to be used after self-attention """ struct MHAttention{P, Q, R} nheads::Int qkv_layer::P - attn_drop::Q + attn_drop_rate::Q projection::R end """ - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop = 0., proj_drop = 0.) + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.) Multi-head self-attention layer. @@ -27,15 +27,15 @@ Multi-head self-attention layer. 
- `planes`: number of input channels - `nheads`: number of heads - `qkv_bias`: whether to use bias in the layer to get the query, key and value - - `attn_drop`: dropout rate after the self-attention layer - - `proj_drop`: dropout rate after the projection layer + - `attn_drop_rate`: dropout rate after the self-attention layer + - `proj_drop_rate`: dropout rate after the projection layer """ function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_drop = 0.0, proj_drop = 0.0) + attn_drop_rate = 0.0, proj_drop_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_drop) - proj = Chain(Dense(planes, planes), Dropout(proj_drop)) + attn_drop = Dropout(attn_drop_rate) + proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate)) return MHAttention(nheads, qkv_layer, attn_drop, proj) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b3f9a8719..8e6202085 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -54,11 +54,12 @@ end """ DropPath(p) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0 and +`identity` otherwise. ([reference](https://arxiv.org/abs/1603.09382)) # Arguments - `p`: rate of Stochastic Depth. """ -DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity +DropPath(p) = p > 0 ? Dropout(p; dims = 4) : identity diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp-linear.jl index e282e2632..550c2ad22 100644 --- a/src/layers/mlp-linear.jl +++ b/src/layers/mlp-linear.jl @@ -15,7 +15,7 @@ end """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0., activation = gelu) + drop_rate =0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -24,18 +24,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models. - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `drop_rate`: Dropout rate. - `activation`: Activation function to use. """ function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout), - Dense(hidden_planes, outplanes), Dropout(dropout)) + drop_rate = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(drop_rate), + Dense(hidden_planes, outplanes), Dropout(drop_rate)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0., activation = gelu) + outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -46,16 +46,16 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `drop_rate`: Dropout rate. - `activation`: Activation function to use. 
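Editor's note: when `gate_layer = identity` this falls back to the plain `mlp_block` defined above. For reference, what `mlp_block(256, 1024; drop_rate = 0.1)` builds, written out as an ordinary Flux chain (the sizes are arbitrary and only for illustration):

```julia
using Flux

Chain(Dense(256 => 1024, gelu), Dropout(0.1), Dense(1024 => 256), Dropout(0.1))
```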
""" function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0.0, activation = gelu) + outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(dropout), + Dropout(drop_rate), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(dropout)) + Dropout(drop_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 942abc823..ed4c47af3 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0., drop_path_rate = 0., activation = gelu) + drop_rate =0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -12,20 +12,22 @@ Creates a feedforward block for the MLPMixer architecture. - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP and/or the channel mixing MLP as a ratio to the number of planes in the block. - `mlp_layer`: the MLP layer to use in the block - - `dropout`: the dropout rate to use in the MLP blocks + - `drop_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0.0, drop_path_rate = 0.0, activation = gelu) + drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] return Chain(SkipConnection(Chain(LayerNorm(planes), swapdims((2, 1, 3)), - mlp_layer(npatches, tokenplanes; activation, dropout), + mlp_layer(npatches, tokenplanes; activation, + drop_rate), swapdims((2, 1, 3)), DropPath(drop_path_rate)), +), SkipConnection(Chain(LayerNorm(planes), - mlp_layer(planes, channelplanes; activation, dropout), + mlp_layer(planes, channelplanes; activation, + drop_rate), DropPath(drop_path_rate)), +)) end @@ -113,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] """ - resmixerblock(planes, npatches; dropout = 0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; drop_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, activation = gelu, λ = 1e-4) Creates a block for the ResMixer architecture. @@ -126,13 +128,13 @@ Creates a block for the ResMixer architecture. 
- `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `mlp_layer`: the MLP block to use - - `dropout`: the dropout rate to use in the MLP blocks + - `drop_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks - `λ`: initialisation constant for the LayerScale """ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - dropout = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) + drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) return Chain(SkipConnection(Chain(Flux.Scale(planes), swapdims((2, 1, 3)), Dense(npatches, npatches), @@ -140,7 +142,7 @@ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, LayerScale(planes, λ), DropPath(drop_path_rate)), +), SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); dropout, + mlp_layer(planes, Int(mlp_ratio * planes); drop_rate, activation), LayerScale(planes, λ), DropPath(drop_path_rate)), +)) @@ -230,7 +232,7 @@ end """ spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, dropout = 0.0, drop_path_rate = 0., + norm_layer = LayerNorm, drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) Creates a feedforward block based on the gMLP model architecture described in the paper. @@ -243,18 +245,19 @@ Creates a feedforward block based on the gMLP model architecture described in th - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `norm_layer`: the normalisation layer to use - - `dropout`: the dropout rate to use in the MLP blocks + - `drop_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, dropout = 0.0, + mlp_layer = gated_mlp_block, drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) channelplanes = Int(mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), - mlp_layer(sgu, planes, channelplanes; activation, dropout), + mlp_layer(sgu, planes, channelplanes; activation, + drop_rate), DropPath(drop_path_rate)), +) end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 012bfef9d..686ddc4d5 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate =0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). @@ -10,23 +10,24 @@ Transformer as used in the base ViT architecture. 
- `depth`: number of attention blocks - `nheads`: number of attention heads - `mlp_ratio`: ratio of MLP layers to the number of input channels - - `dropout`: dropout rate + - `drop_rate`: dropout rate """ -function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.0) +function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate = 0.0) layers = [Chain(SkipConnection(prenorm(planes, - MHAttention(planes, nheads; attn_drop = dropout, - proj_drop = dropout)), +), + MHAttention(planes, nheads; + attn_drop_rate = drop_rate, + proj_drop_rate = drop_rate)), +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); - dropout)), +)) + drop_rate)), +)) for _ in 1:depth] return Chain(layers) end """ vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, - emb_dropout = 0.1, pool = :class, nclasses = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + emb_drop_rate = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -40,22 +41,23 @@ Creates a Vision Transformer (ViT) model. - `depth`: number of blocks in the transformer - `nheads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer - - `dropout`: dropout rate + - `drop_rate`: dropout rate - `emb_dropout`: dropout rate for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, - emb_dropout = 0.1, pool = :class, nclasses = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + emb_drop_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), ViPosEmbedding(embedplanes, npatches + 1), - Dropout(emb_dropout), - transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout), + Dropout(emb_drop_rate), + transformer_encoder(embedplanes, depth, nheads; mlp_ratio, + drop_rate), (pool == :class) ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end @@ -98,7 +100,6 @@ function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))" kwargs = vit_configs[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) - return ViT(layers) end From a038ff836148eea7b7919dd4275053a74b500be9 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 23 Jun 2022 11:49:46 +0530 Subject: [PATCH 04/64] Get some stuff to work 1. Some docs 2. 
Basic tests for ResNet and ResNeXt now pass --- src/Metalhead.jl | 7 +- src/convnets/resne(x)t.jl | 377 ++++++++++++++++++++++++++++++++++++++ src/convnets/resnet.jl | 181 ------------------ src/convnets/resnext.jl | 126 ------------- src/layers/drop.jl | 43 +++-- test/convnets.jl | 19 +- 6 files changed, 407 insertions(+), 346 deletions(-) create mode 100644 src/convnets/resne(x)t.jl delete mode 100644 src/convnets/resnet.jl delete mode 100644 src/convnets/resnext.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 9a60ad351..5463c64de 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -22,8 +22,7 @@ include("convnets/alexnet.jl") include("convnets/vgg.jl") include("convnets/inception.jl") include("convnets/googlenet.jl") -include("convnets/resnet.jl") -include("convnets/resnext.jl") +include("convnets/resne(x)t.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") @@ -40,7 +39,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, -# ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -49,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, # :ResNet, +for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, :ResNet, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl new file mode 100644 index 000000000..abf7193b1 --- /dev/null +++ b/src/convnets/resne(x)t.jl @@ -0,0 +1,377 @@ +# returns `DropBlock`s for each block of the ResNet +function _drop_blocks(drop_block_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_block_prob, 5, 0.25), + DropBlock(drop_block_prob, 3, 1.00), + ] +end + +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) +end + +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 + pool = MeanPool((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), in_channels => out_channels; bias = false), + norm_layer(out_channels)) +end + +function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = expansion_factor(basicblock) + @assert cardinality==1 "`basicblock` only supports cardinality of 1" + @assert base_width==64 "`basicblock` does not support changing base width" + first_planes = planes ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? first_dilation : dilation + conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, + dilation = first_dilation, bias = false), + norm_layer(first_planes)) + drop_block = drop_block + conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = dilation, + dilation = dilation, bias = false), + norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + activation) +end +expansion_factor(::typeof(basicblock)) = 1 + +function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = expansion_factor(bottleneck) + width = floor(Int, planes * (base_width / 64)) * cardinality + first_planes = width ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? first_dilation : dilation + conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), + norm_layer(first_planes, activation)) + conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, + dilation = first_dilation, groups = cardinality, bias = false), + norm_layer(width)) + conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, + drop_path)), + activation) +end +expansion_factor(::typeof(bottleneck)) = 4 + +# Makes the main stages of the ResNet model. This is an internal function and should not be +# used by end-users. `block_fn` is a function that returns a single block of the ResNet. +# See `basicblock` and `bottleneck` for examples. A block must define a function +# `expansion(::typeof(block))` that returns the expansion factor of the block. +function _make_blocks(block_fn, channels, block_repeats, inplanes; + reduce_first = 1, output_stride = 32, down_kernel_size = (1, 1), + avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, + kwargs...) + expansion = expansion_factor(block_fn) + kwarg_dict = Dict(kwargs...) + stages = [] + net_block_idx = 1 + net_stride = 4 + dilation = prev_dilation = 1 + for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, + block_repeats, + _drop_blocks(drop_block_rate))) + # Stride calculations for each stage + stride = stage_idx == 1 ? 1 : 2 + if net_stride >= output_stride + dilation *= stride + stride = 1 + else + net_stride *= stride + end + # use average pooling for projection skip connection between stages/downsample. + downsample = identity + if stride != 1 || inplanes != planes * expansion + downsample_fn = avg_down ? 
downsample_avg : downsample_conv + downsample = downsample_fn(down_kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation = dilation, + norm_layer = kwarg_dict[:norm_layer]) + end + # arguments to be passed into the block function + block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, + :drop_block => drop_block, kwargs...) + blocks = [] + for block_idx in 1:num_blocks + downsample = block_idx == 1 ? downsample : identity + stride = block_idx == 1 ? stride : 1 + # stochastic depth linear decay rule + block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + push!(blocks, + block_fn(inplanes, planes; stride, downsample, + first_dilation = prev_dilation, + drop_path = DropPath(block_dpr), block_kwargs...)) + prev_dilation = dilation + inplanes = planes * expansion + net_block_idx += 1 + end + push!(stages, Chain(blocks...)) + end + return Chain(stages...) +end + +""" + resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, + cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, + replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), + avg_down = false, activation = relu, norm_layer = BatchNorm, + drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, + block_kwargs...) + +Creates the layers of a ResNe(X)t model. If you are an end-user, you should probably use +[ResNet](@ref) instead and pass in the parameters you want to modify as optional parameters +there. + +# Arguments: + + - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for + example. + + - `layers`: A list of integers representing the number of blocks in each stage. + - `nclasses`: The number of output classes. The default value is 1000. + - `inchannels`: The number of input channels to the model. The default value is 3. + - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. + - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. + This is used for [ResNeXt](@ref)-like models. The default value is 1. + - `base_width`: The base width of each bottleneck block. It is the factor determining + the number of bottleneck channels: `planes * base_width / 64 * cardinality`. + The default value is 64. + - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. + - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: + + + `:default` - a single 7x7 convolution layer with a width of `stem_width` + + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` + + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. + The default value is `:default`. + - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a + convolution layer. The default value is false. + - `reduce_first`: Reduction factor for first convolution output width of residual blocks, + Default is 1 for all architectures except SE-Nets, where it is 2. + - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the + skip connection. The default value is (1, 1) for all architectures except + SE-Nets, where it is (3, 3). + - `avg_down`: Use average pooling for projection skip connection between stages/downsample. + - `activation`: The activation function to use. The default value is `relu`. + - `norm_layer`: The normalization layer to use. 
The default value is `BatchNorm`.
+  - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0.
+  - `drop_path_rate`: The rate to use for [`DropPath`](@ref). The default value is 0.0.
+  - `drop_block_rate`: The rate to use for [`DropBlock`](@ref). The default value is 0.0.
+
+If you are an end-user trying to tweak the ResNet model, note that there is no guarantee that
+all combinations of parameters will work. In particular, tweaking `block_kwargs` is not
+advised unless you know what you are doing.
+"""
+function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
+                cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default,
+                replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1),
+                avg_down = false, activation = relu, norm_layer = BatchNorm,
+                drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0,
+                block_kwargs...)
+    @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)"
+    @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]"
+    # Stem
+    inplanes = stem_type in (:deep, :deep_tiered) ? stem_width * 2 : 64
+    if stem_type in (:deep, :deep_tiered)
+        stem_channels = (stem_width, stem_width)
+        if stem_type == :deep_tiered
+            stem_channels = (3 * (stem_width ÷ 4), stem_width)
+        end
+        conv1 = Chain(Conv((3, 3), inchannels => stem_channels[1]; stride = 2, pad = 1,
+                           bias = false),
+                      norm_layer(stem_channels[1], activation),
+                      Conv((3, 3), stem_channels[1] => stem_channels[2]; pad = 1,
+                           bias = false),
+                      norm_layer(stem_channels[2], activation),
+                      Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false))
+    else
+        conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false)
+    end
+    bn1 = norm_layer(inplanes, activation)
+    # Stem pooling
+    if replace_stem_pool
+        stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1,
+                              bias = false),
+                         norm_layer(inplanes, activation))
+    else
+        stempool = MaxPool((3, 3); stride = 2, pad = 1)
+    end
+    stem = Chain(conv1, bn1, stempool)
+    # Feature Blocks
+    channels = [64, 128, 256, 512]
+    stage_blocks = _make_blocks(block, channels, layers, inplanes; cardinality, base_width,
+                                output_stride, reduce_first, avg_down,
+                                down_kernel_size, activation, norm_layer,
+                                drop_block_rate, drop_path_rate, block_kwargs...)
+    # Head (Pooling and Classifier)
+    expansion = expansion_factor(block)
+    num_features = 512 * expansion
+    classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten,
+                       Dense(num_features, nclasses))
+    return Chain(Chain(stem, stage_blocks), classifier)
+end
+
+const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]),
+                           34 => (basicblock, [3, 4, 6, 3]),
+                           50 => (bottleneck, [3, 4, 6, 3]),
+                           101 => (bottleneck, [3, 4, 23, 3]),
+                           152 => (bottleneck, [3, 8, 36, 3]))
+
+"""
+    ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...)
+
+Creates a ResNet model.
+([reference](https://arxiv.org/abs/1512.03385))
+
+# Arguments:
+
+  - `depth`: The depth of the `ResNet` model. Must be one of `[18, 34, 50, 101, 152]`.
+  - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet.
+  - `nclasses`: The number of output classes. The default value is 1000.
+
+Apart from these, the model can also take any of the following optional arguments:
+
+  - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for
+    example.
+
+  - `layers`: A list of integers representing the number of blocks in each stage.
+  - `inchannels`: The number of input channels to the model. The default value is 3.
+  - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32.
+  - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block.
+    This is used for [ResNeXt](@ref)-like models. The default value is 1.
+  - `base_width`: The base width of each bottleneck block. It is the factor determining
+    the number of bottleneck channels: `planes * base_width / 64 * cardinality`.
+    The default value is 64.
+  - `stem_width`: The number of channels in the convolution in the stem. The default value is 64.
+  - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :deep_tiered]`:
+
+      + `:default` - a single 7x7 convolution layer with a width of `stem_width`
+      + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2`
+      + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`.
+        The default value is `:default`.
+  - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a
+    convolution layer. The default value is false.
+  - `reduce_first`: Reduction factor for first convolution output width of residual blocks.
+    Default is 1 for all architectures except SE-Nets, where it is 2.
+  - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the
+    skip connection. The default value is (1, 1) for all architectures except
+    SE-Nets, where it is (3, 3).
+  - `avg_down`: Use average pooling for projection skip connection between stages/downsample.
+  - `activation`: The activation function to use. The default value is `relu`.
+  - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`.
+  - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0.
+  - `drop_path_rate`: The rate to use for [`DropPath`](@ref). The default value is 0.0.
+  - `drop_block_rate`: The rate to use for [`DropBlock`](@ref). The default value is 0.0.
+
+See also [`resnet`](@ref) for more details.
+
+!!! warning
+
+    Pretrained models are not supported for all parameter combinations of `ResNet`.
+"""
+struct ResNet
+    layers::Any
+end
+@functor ResNet
+
+function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...)
+    @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]"
+    model = resnet(resnet_config[depth]...; nclasses, kwargs...)
+    pretrain && loadpretrain!(model, string("resnet", depth))
+    return model
+end
+
+"""
+    ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, nclasses = 1000,
+            kwargs...)
+
+Creates a ResNeXt model.
+([reference](https://arxiv.org/abs/1611.05431))
+
+# Arguments:
+
+  - `depth`: The depth of the `ResNeXt` model. Must be one of `[50, 101, 152]`.
+  - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block
+    of the `ResNeXt` model. The default value is 4.
+  - `base_width`: The base width of each bottleneck block. It is the factor determining
+    the number of bottleneck channels. The default value is 32.
+  - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet.
+  - `nclasses`: The number of output classes. The default value is 1000.
+
+Apart from these, the model can also take any of the following optional arguments:
+
+  - `block`: The block to use in the ResNet model.
See [basicblock](@ref) and [bottleneck](@ref) for + example. + + - `layers`: A list of integers representing the number of blocks in each stage. + - `inchannels`: The number of input channels to the model. The default value is 3. + - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. + - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. + This is used for [ResNeXt](@ref)-like models. The default value is 1. + - `base_width`: The base width of each bottleneck block. It is the factor determining + the number of bottleneck channels: `planes * base_width / 64 * cardinality`. + The default value is 64. + - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. + - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: + + + `:default` - a single 7x7 convolution layer with a width of `stem_width` + + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` + + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. + The default value is `:default`. + - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a + convolution layer. The default value is false. + - `reduce_first`: Reduction factor for first convolution output width of residual blocks, + Default is 1 for all architectures except SE-Nets, where it is 2. + - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the + skip connection. The default value is (1, 1) for all architectures except + SE-Nets, where it is (3, 3). + - `avg_down`: Use average pooling for projection skip connection between stages/downsample. + - `activation`: The activation function to use. The default value is `relu`. + - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. + - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. + - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. + - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. + +See also [`resnet`](@ref) for more details. + +!!! warning + + Pretrained models are not currently supported for `ResNeXt`. +""" +struct ResNeXt + layers::Any +end +@functor ResNeXt + +function ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, + nclasses = 1000, + kwargs...) + @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" + model = resnet(resnet_config[depth]...; cardinality, base_width, nclasses, kwargs...) + pretrain && + loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + return model +end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl deleted file mode 100644 index 875421360..000000000 --- a/src/convnets/resnet.jl +++ /dev/null @@ -1,181 +0,0 @@ -function drop_blocks(drop_prob = 0.0) - return [ - identity, - identity, - DropBlock(drop_prob, 5, 0.25), - DropBlock(drop_prob, 3, 1.00), - ] -end - -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? 
first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(out_channels)) -end - -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 - pool = avg_pool_fn((2, 2); stride = avg_stride, pad) - end - return Chain(pool, - Conv((1, 1), in_channels => out_channels; bias = false), - norm_layer(out_channels)) -end - -function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) - expansion = expansion_factor(basicblock) - @assert cardinality==1 "BasicBlock only supports cardinality of 1" - @assert base_width==64 "BasicBlock does not support changing base width" - first_planes = planes ÷ reduce_first - outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation - conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, - dilation = first_dilation, bias = false), - norm_layer(first_planes)) - drop_block = drop_block === identity ? identity : drop_block - conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; stride, pad = dilation, - dilation = dilation, bias = false), - norm_layer(outplanes)) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), - activation) -end -expansion_factor(::typeof(basicblock)) = 1 - -function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) - expansion = expansion_factor(bottleneck) - width = floor(Int, planes * (base_width / 64)) * cardinality - first_planes = width ÷ reduce_first - outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation - conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), - norm_layer(first_planes)) - conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, - dilation = first_dilation, groups = cardinality, bias = false), - norm_layer(width)) - drop_block = drop_block === identity ? identity : drop_block() - conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, drop_block, - activation, conv_bn3, drop_path)), - activation) -end -expansion_factor(::typeof(bottleneck)) = 4 - -function make_blocks(block_fn, channels, block_repeats, inplanes; - reduce_first = 1, output_stride = 32, down_kernel_size = 1, - avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - kwargs...) - expansion = expansion_factor(block_fn) - kwarg_dict = Dict(kwargs...) - stages = [] - net_block_idx = 1 - net_stride = 4 - dilation = prev_dilation = 1 - for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, - block_repeats, - drop_blocks(drop_block_rate))) - stride = stage_idx == 1 ? 
1 : 2 - if net_stride >= output_stride - dilation *= stride - stride = 1 - else - net_stride *= stride - end - # first block needs to be handled differently for downsampling - downsample = identity - if stride != 1 || inplanes != planes * expansion - downsample = avg_down ? - downsample_avg(down_kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation = prev_dilation, - norm_layer = kwarg_dict[:norm_layer]) : - downsample_conv(down_kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation = prev_dilation, - norm_layer = kwarg_dict[:norm_layer]) - end - block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => drop_block, kwargs...) - blocks = [] - for block_idx in 1:num_blocks - downsample = block_idx == 1 ? downsample : identity - stride = block_idx == 1 ? stride : 1 - # stochastic depth linear decay rule - block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) - push!(blocks, - block_fn(inplanes, planes; stride, downsample, - first_dilation = prev_dilation, - drop_path = DropPath(block_dpr), block_kwargs...)) - prev_dilation = dilation - inplanes = planes * expansion - net_block_idx += 1 - end - push!(stages, Chain(blocks...)) - end - return Chain(stages...) -end - -function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, - cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), - avg_down = false, activation = relu, norm_layer = BatchNorm, - drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, - block_kwargs...) - @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" - @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" - # Stem - inplanes = stem_type == :deep ? stem_width * 2 : 64 - if stem_type == :deep - stem_channels = (stem_width, stem_width) - if stem_type == :deep_tiered - stem_channels = (3 * (stem_width ÷ 4), stem_width) - end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, - bias = false), - norm_layer(stem_channels[1], activation), - Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, - bias = false), - norm_layer(stem_channels[2], activation), - Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) - else - conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) - end - bn1 = norm_layer(inplanes, activation) - # Stem pooling - if replace_stem_pool - stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, - bias = false), - norm_layer(inplanes, activation)) - else - stempool = MaxPool((3, 3); stride = 2, pad = 1) - end - stem = Chain(conv1, bn1, stempool) - # Feature Blocks - channels = [64, 128, 256, 512] - stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, - output_stride, reduce_first, avg_down, - down_kernel_size, activation, norm_layer, - drop_block_rate, drop_path_rate, block_kwargs...) 
- # Head (Pooling and Classifier) - expansion = expansion_factor(block) - num_features = 512 * expansion - classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, - Dense(num_features, num_classes)) - - return Chain(Chain(stem, stage_blocks), classifier) -end diff --git a/src/convnets/resnext.jl b/src/convnets/resnext.jl deleted file mode 100644 index fc00bb180..000000000 --- a/src/convnets/resnext.jl +++ /dev/null @@ -1,126 +0,0 @@ -""" - resnextblock(inplanes, outplanes, cardinality, width, downsample = false) - -Create a basic residual block as defined in the paper for ResNeXt -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `downsample`: set to `true` to downsample the input -""" -function resnextblock(inplanes, outplanes, cardinality, width, downsample = false) - stride = downsample ? 2 : 1 - hidden_channels = cardinality * width - return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false)..., - conv_bn((3, 3), hidden_channels, hidden_channels; - stride = stride, pad = 1, bias = false, groups = cardinality)..., - conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false)...) -end - -""" - resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @. relu(x) + relu(y); - block_config, nclasses = 1000) - -Create a ResNeXt model -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments - - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `widen_factor`: the factor by which the width of the bottleneck is increased after each stage - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnext(cardinality, width, widen_factor = 2, - connection = (x, y) -> @. relu(x) + relu(y); - block_config, nclasses = 1000) - inplanes = 64 - baseplanes = 128 - layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3))) - push!(layers, MaxPool((3, 3); stride = (2, 2), pad = (1, 1))) - for (i, nrepeats) in enumerate(block_config) - # output planes within a block - outplanes = baseplanes * widen_factor - # push first skip connection on using first residual - # downsample the residual path if this is the first repetition of a block - push!(layers, - Parallel(connection, - resnextblock(inplanes, outplanes, cardinality, width, i != 1), - skip_projection(inplanes, outplanes, i != 1))) - # push remaining skip connections on using second residual - inplanes = outplanes - for _ in 2:nrepeats - push!(layers, - Parallel(connection, - resnextblock(inplanes, outplanes, cardinality, width, false), - skip_identity(inplanes, outplanes, false))) - end - baseplanes = outplanes - # double width after every cluster of blocks - width *= widen_factor - end - return Chain(Chain(layers), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(inplanes, nclasses))) -end - -""" - ResNeXt(cardinality, width; block_config, nclasses = 1000) - -Create a ResNeXt model -([reference](https://arxiv.org/abs/1611.05431)). 
- -# Arguments - - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -struct ResNeXt - layers::Any -end - -function ResNeXt(cardinality, width; block_config, nclasses = 1000) - layers = resnext(cardinality, width; block_config, nclasses) - return ResNeXt(layers) -end - -@functor ResNeXt - -(m::ResNeXt)(x) = m.layers(x) - -backbone(m::ResNeXt) = m.layers[1] -classifier(m::ResNeXt) = m.layers[2] - -const resnext_config = Dict(50 => (3, 4, 6, 3), - 101 => (3, 4, 23, 3), - 152 => (3, 8, 36, 3)) - -""" - ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000) - -Create a ResNeXt model with specified configuration. Currently supported values for `config` are (50, 101). -([reference](https://arxiv.org/abs/1611.05431)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -!!! warning - - `ResNeXt` does not currently support pretrained weights. - -See also [`Metalhead.resnext`](#). -""" -function ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, - nclasses = 1000) - @assert config in keys(resnext_config) "`config` must be one of $(sort(collect(keys(resnext_config))))" - model = ResNeXt(cardinality, width; block_config = resnext_config[config], nclasses) - pretrain && loadpretrain!(model, string("ResNeXt", config)) - return model -end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 8e6202085..df66a3e7f 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,54 +1,53 @@ """ - DropBlock(drop_prob = 0.1, block_size = 7) + DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) Implements DropBlock, a regularization method for convolutional networks. 
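A typical construction, matching what the `_drop_blocks` helper passes in for the deeper
ResNet stages (the probability here is only an assumed example):

    DropBlock(0.1, 5, 0.25)    # drop_block_prob = 0.1, block_size = 5, gamma_scale = 0.25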
([reference](https://arxiv.org/pdf/1810.12890.pdf))
 """
 struct DropBlock{F}
-    drop_prob::F
+    drop_block_prob::F
     block_size::Integer
     gamma_scale::F
 end
 @functor DropBlock

-(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size, m.gamma_scale)
+(m::DropBlock)(x) = dropblock(x, m.drop_block_prob, m.block_size, m.gamma_scale)

-function DropBlock(drop_prob = 0.1, block_size = 7, gamma_scale = 1.0)
-    return DropBlock(drop_prob, block_size, gamma_scale)
+function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0)
+    if drop_block_prob == 0.0
+        return identity
+    end
+    @assert 0 ≤ drop_block_prob ≤ 1 "drop_block_prob must be between 0 and 1, got $drop_block_prob"
+    @assert 0 ≤ gamma_scale ≤ 1 "gamma_scale must be between 0 and 1, got $gamma_scale"
+    return DropBlock(drop_block_prob, block_size, gamma_scale)
 end

-function _dropblock_checks(x, drop_prob, gamma_scale, T)
+function _dropblock_checks(x::T) where {T}
     if !(T <: AbstractArray)
         throw(ArgumentError("x must be an `AbstractArray`"))
     end
     if ndims(x) != 4
         throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`"))
     end
-    @assert drop_prob < 0||drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob"
-    @assert gamma_scale < 0||gamma_scale > 1 "gamma_scale must be between 0 and 1, got $gamma_scale"
-end
-ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, gamma_scale, T)
-
-function dropblock(x::T, drop_prob, block_size::Integer, gamma_scale) where {T}
-    _dropblock_checks(x, drop_prob, gamma_scale, T)
-    if drop_prob == 0
-        return x
-    end
-    return _dropblock(x, drop_prob, block_size, gamma_scale)
 end
+ChainRulesCore.@non_differentiable _dropblock_checks(x)

-function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size, gamma_scale) where {T}
+function dropblock(x::AbstractArray{T, 4}, drop_block_prob, block_size,
+                   gamma_scale) where {T}
+    _dropblock_checks(x)
     H, W, _, _ = size(x)
     total_size = H * W
     clipped_block_size = min(block_size, min(H, W))
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size^2 /
+    gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 /
            ((W - block_size + 1) * (H - block_size + 1))
     block_mask = rand_like(x) .< gamma
-    block_mask = maxpool(convert(T, block_mask), (clipped_block_size, clipped_block_size);
-                         stride = 1, padding = clipped_block_size ÷ 2)
+    block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size);
+                         stride = 1, pad = clipped_block_size ÷ 2)
     block_mask = 1 .- block_mask
     normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6))
-    return x * block_mask * normalize_scale
+    return x .* block_mask .* normalize_scale
 end

 """
diff --git a/test/convnets.jl b/test/convnets.jl
index 9d1645865..e547084e0 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -30,23 +30,16 @@ GC.gc()
     @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
         m = ResNet(sz)
         @test size(m(x_256)) == (1000, 1)
-        if (ResNet, sz) in PRETRAINED_MODELS
-            @test acctest(ResNet(sz, pretrain = true))
-        else
-            @test_throws ArgumentError ResNet(sz, pretrain = true)
-        end
+        ## TODO: find a way to port pretrained models to the new ResNet API
+        # if (ResNet, sz) in PRETRAINED_MODELS
+        #     @test acctest(ResNet(sz, pretrain = true))
+        # else
+        #     @test_throws ArgumentError ResNet(sz, pretrain = true)
+        # end
         @test gradtest(m, x_256)
         GC.safepoint()
         GC.gc()
     end
-
-    @testset "Shortcut C" begin
-        m = Metalhead.resnet(Metalhead.basicblock, :C;
-                             channel_config = [1, 1],
-                             block_config = [2, 2, 2, 2])
- @test size(m(x_256)) == (1000, 1) - @test gradtest(m, x_256) - end end GC.safepoint() From de079bcb39314d53b8bfd3a421034a2db70a48ef Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 23 Jun 2022 12:32:21 +0530 Subject: [PATCH 05/64] Tweaks - I --- src/convnets/inception.jl | 2 +- src/layers/attention.jl | 8 ++++---- src/layers/normalise.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index e4106e957..ead229551 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -340,7 +340,7 @@ struct Inceptionv4 end function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = inceptionv4(; inchannels, dropout, nclasses) + layers = inceptionv4(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "Inceptionv4") return Inceptionv4(layers) end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index b6e7b7678..3cefe7c0d 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,5 +1,5 @@ """ - MHAttention(nheads::Integer, qkv_layer, attn_drop, projection) + MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection) Multi-head self-attention layer. @@ -34,9 +34,9 @@ function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = fals attn_drop_rate = 0.0, proj_drop_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_drop_rate) + attn_drop_rate = Dropout(attn_drop_rate) proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate)) - return MHAttention(nheads, qkv_layer, attn_drop, proj) + return MHAttention(nheads, qkv_layer, attn_drop_rate, proj) end @functor MHAttention @@ -52,7 +52,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} seq_len * batch_size) query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) - attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) + attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) pre_projection = reshape(batched_mul(attention, value_reshaped), diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 2d5e6399a..c767bd1e0 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -24,4 +24,4 @@ function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) return ChannelLayerNorm(diag, ϵ) end -(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) +(m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) From 4fa28d42ca6345c69d46cd029a579f715b3070b2 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 25 Jun 2022 15:42:22 +0530 Subject: [PATCH 06/64] Make pretrain condition explicit --- src/convnets/alexnet.jl | 4 +++- src/convnets/densenet.jl | 4 +++- src/convnets/googlenet.jl | 4 +++- src/convnets/inception.jl | 16 ++++++++++++---- src/convnets/mobilenet.jl | 11 +++++++++-- src/convnets/squeezenet.jl | 4 +++- 6 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 405272dd2..87f2c288e 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -49,7 +49,9 @@ 
end function AlexNet(; pretrain = false, nclasses = 1000) layers = alexnet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "AlexNet") + if pretrain + loadpretrain!(layers, "AlexNet") + end return AlexNet(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 374909bb1..9da4e08b2 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -162,6 +162,8 @@ See also [`Metalhead.densenet`](#). function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))." model = DenseNet(densenet_config[config]; nclasses = nclasses) - pretrain && loadpretrain!(model, string("DenseNet", config)) + if pretrain + loadpretrain!(model, string("DenseNet", config)) + end return model end diff --git a/src/convnets/googlenet.jl b/src/convnets/googlenet.jl index 318463494..946d0d7f7 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/googlenet.jl @@ -86,7 +86,9 @@ end function GoogLeNet(; pretrain = false, nclasses = 1000) layers = googlenet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "GoogLeNet") + if pretrain + loadpretrain!(layers, "GoogLeNet") + end return GoogLeNet(layers) end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index ead229551..ba30fa86f 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -182,7 +182,9 @@ end function Inceptionv3(; pretrain = false, nclasses = 1000) layers = inceptionv3(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "Inceptionv3") + if pretrain + loadpretrain!(layers, "Inceptionv3") + end return Inceptionv3(layers) end @@ -341,7 +343,9 @@ end function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionv4(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "Inceptionv4") + if pretrain + loadpretrain!(layers, "Inceptionv4") + end return Inceptionv4(layers) end @@ -476,7 +480,9 @@ end function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "InceptionResNetv2") + if pretrain + loadpretrain!(layers, "InceptionResNetv2") + end return InceptionResNetv2(layers) end @@ -584,7 +590,9 @@ Creates an Xception model. 
""" function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = xception(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "xception") + if pretrain + loadpretrain!(layers, "xception") + end return Xception(layers) end diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 93eba1c06..b7dfcd6f3 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -90,7 +90,9 @@ end function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv1")) + if pretrain + loadpretrain!(layers, string("MobileNetv1")) + end return MobileNetv1(layers) end @@ -189,6 +191,9 @@ function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) pretrain && loadpretrain!(layers, string("MobileNetv2")) + if pretrain + loadpretrain!(layers, string("MobileNetv2")) + end return MobileNetv2(layers) end @@ -319,7 +324,9 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = max_width = (mode == :large) ? 1280 : 1024 layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv3", mode)) + if pretrain + loadpretrain!(layers, string("MobileNetv3", mode)) + end return MobileNetv3(layers) end diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index c4de36acc..df458f9ff 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -68,7 +68,9 @@ end function SqueezeNet(; pretrain = false) layers = squeezenet() - pretrain && loadpretrain!(layers, "SqueezeNet") + if pretrain + loadpretrain!(layers, "SqueezeNet") + end return SqueezeNet(layers) end From 7846f8bd8aac685a908f23af3b2183c63d84b9b7 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 28 Jun 2022 18:39:41 +0530 Subject: [PATCH 07/64] More declarative interface for ResNet 1. Less keywords for the user to worry about 2. Delete `ResNeXt` just for now --- src/convnets/efficientnet.jl | 111 ++++++------ src/convnets/mobilenet.jl | 2 +- src/convnets/resne(x)t.jl | 315 +++++++++-------------------------- 3 files changed, 134 insertions(+), 294 deletions(-) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 1465eb238..da9000468 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -6,19 +6,21 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). 
# Arguments -- `scalings`: global width and depth scaling (given as a tuple) -- `block_config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - `n`: number of block repetitions (will be scaled by global depth scaling) - - `k`: kernel size - - `s`: kernel stride - - `e`: expansion ratio - - `i`: block input channels (will be scaled by global width scaling) - - `o`: block output channels (will be scaled by global width scaling) -- `inchannels`: number of input channels -- `nclasses`: number of output classes -- `max_width`: maximum number of output channels before the fully connected - classification blocks + - `scalings`: global width and depth scaling (given as a tuple) + + - `block_config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `n`: number of block repetitions (will be scaled by global depth scaling) + + `k`: kernel size + + `s`: kernel stride + + `e`: expansion ratio + + `i`: block input channels (will be scaled by global width scaling) + + `o`: block output channels (will be scaled by global width scaling) + - `inchannels`: number of input channels + - `nclasses`: number of output classes + - `max_width`: maximum number of output channels before the fully connected + classification blocks """ function efficientnet(scalings, block_config; inchannels = 3, nclasses = 1000, max_width = 1280) @@ -64,34 +66,33 @@ end # i: block input channels # o: block output channels const efficientnet_block_configs = [ -# (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), + # (n, k, s, e, i, o) + (1, 3, 1, 1, 32, 16), + (2, 3, 2, 6, 16, 24), + (2, 5, 2, 6, 24, 40), + (3, 3, 2, 6, 40, 80), + (3, 5, 1, 6, 80, 112), (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320) + (1, 3, 1, 6, 192, 320), ] # w: width scaling # d: depth scaling # r: image resolution const efficientnet_global_configs = Dict( -# ( r, ( w, d)) - :b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6)) -) + # (r, (w, d)) + :b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) struct EfficientNet - layers::Any + layers::Any end """ @@ -103,27 +104,29 @@ See also [`efficientnet`](#). 
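As a quick sketch, the stock `:b0` configuration from the tables above can be passed in
directly (the `EfficientNet(name::Symbol)` constructor below does exactly this):

    EfficientNet(efficientnet_global_configs[:b0][2], efficientnet_block_configs)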
# Arguments -- `scalings`: global width and depth scaling (given as a tuple) -- `block_config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - `n`: number of block repetitions (will be scaled by global depth scaling) - - `k`: kernel size - - `s`: kernel stride - - `e`: expansion ratio - - `i`: block input channels (will be scaled by global width scaling) - - `o`: block output channels (will be scaled by global width scaling) -- `inchannels`: number of input channels -- `nclasses`: number of output classes -- `max_width`: maximum number of output channels before the fully connected - classification blocks + - `scalings`: global width and depth scaling (given as a tuple) + + - `block_config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `n`: number of block repetitions (will be scaled by global depth scaling) + + `k`: kernel size + + `s`: kernel stride + + `e`: expansion ratio + + `i`: block input channels (will be scaled by global width scaling) + + `o`: block output channels (will be scaled by global width scaling) + - `inchannels`: number of input channels + - `nclasses`: number of output classes + - `max_width`: maximum number of output channels before the fully connected + classification blocks """ function EfficientNet(scalings, block_config; inchannels = 3, nclasses = 1000, max_width = 1280) - layers = efficientnet(scalings, block_config; - inchannels = inchannels, - nclasses = nclasses, - max_width = max_width) - return EfficientNet(layers) + layers = efficientnet(scalings, block_config; + inchannels = inchannels, + nclasses = nclasses, + max_width = max_width) + return EfficientNet(layers) end @functor EfficientNet @@ -141,13 +144,13 @@ See also [`efficientnet`](#). # Arguments -- `name`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) -- `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `name`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) @assert name in keys(efficientnet_global_configs) - "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" + "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) pretrain && loadpretrain!(model, string("efficientnet-", name)) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index b7dfcd6f3..25067a631 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -28,7 +28,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
function mobilenetv1(width_mult, config; activation = relu, inchannels = 3, - fcsize = 1024, + fcsize = 1024, nclasses = 1000) layers = [] for (dw, outch, stride, nrepeats) in config diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index abf7193b1..50b75fbfd 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -8,18 +8,18 @@ function _drop_blocks(drop_block_prob = 0.0) ] end -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, +function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, first_dilation = nothing, norm_layer = BatchNorm) kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size first_dilation = kernel_size[1] > 1 ? (!isnothing(first_dilation) ? first_dilation : dilation) : 1 pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, dilation = first_dilation, bias = false), - norm_layer(out_channels)) + norm_layer(outplanes)) end -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, +function downsample_avg(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, first_dilation = nothing, norm_layer = BatchNorm) avg_stride = dilation == 1 ? stride : 1 if stride == 1 && dilation == 1 @@ -29,8 +29,8 @@ function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dila pool = MeanPool((2, 2); stride = avg_stride, pad) end return Chain(pool, - Conv((1, 1), in_channels => out_channels; bias = false), - norm_layer(out_channels)) + Conv((1, 1), inplanes => outplanes; bias = false), + norm_layer(outplanes)) end function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, @@ -78,16 +78,61 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina end expansion_factor(::typeof(bottleneck)) = 4 +function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, + norm_layer = BatchNorm, activation = relu) + @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + # Main stem + inplanes = stem_type == :deep ? 
stem_width * 2 : 64 + if stem_type == :deep + stem_channels = (stem_width, stem_width) + if stem_type == :deep_tiered + stem_channels = (3 * (stem_width ÷ 4), stem_width) + end + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + bias = false), + norm_layer(stem_channels[1], activation), + Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, + bias = false), + norm_layer(stem_channels[2], activation), + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) + else + conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) + end + bn1 = norm_layer(inplanes, activation) + # Stem pooling + if replace_stem_pool + stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, + bias = false), + norm_layer(inplanes, activation)) + else + stempool = MaxPool((3, 3); stride = 2, pad = 1) + end + return Chain(conv1, bn1, stempool) +end + +function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), + stride = 1, dilation = 1, first_dilation = dilation, + norm_layer = BatchNorm) + if stride != 1 || inplanes != planes * expansion + downsample = downsample_fn(kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation, + norm_layer) + else + downsample = identity + end + return downsample +end + # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function _make_blocks(block_fn, channels, block_repeats, inplanes; - reduce_first = 1, output_stride = 32, down_kernel_size = (1, 1), - avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - kwargs...) +function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, + downsample_fn = downsample_conv, downsample_args::NamedTuple = (), + drop_block_rate = 0.0, drop_path_rate = 0.0, + block_args::NamedTuple = ()) + @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" expansion = expansion_factor(block_fn) - kwarg_dict = Dict(kwargs...) stages = [] net_block_idx = 1 net_stride = 4 @@ -103,17 +148,10 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; else net_stride *= stride end - # use average pooling for projection skip connection between stages/downsample. - downsample = identity - if stride != 1 || inplanes != planes * expansion - downsample_fn = avg_down ? downsample_avg : downsample_conv - downsample = downsample_fn(down_kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation = dilation, - norm_layer = kwarg_dict[:norm_layer]) - end - # arguments to be passed into the block function - block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => drop_block, kwargs...) + # Downsample block; either a (default) convolution-based block or a pooling-based block. + downsample = downsample_block(downsample_fn, inplanes, planes, expansion; + downsample_args...) + # Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? 
downsample : identity @@ -123,7 +161,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, - drop_path = DropPath(block_dpr), block_kwargs...)) + drop_path = DropPath(block_dpr), drop_block, block_args...)) prev_dilation = dilation inplanes = planes * expansion net_block_idx += 1 @@ -133,103 +171,25 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; return Chain(stages...) end -""" - resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), - avg_down = false, activation = relu, norm_layer = BatchNorm, - drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, - block_kwargs...) - -Creates the layers of a ResNe(X)t model. If you are an end-user, you should probably use -[ResNet](@ref) instead and pass in the parameters you want to modify as optional parameters -there. - -# Arguments: - - - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for - example. - - - `layers`: A list of integers representing the number of blocks in each stage. - - `nclasses`: The number of output classes. The default value is 1000. - - `inchannels`: The number of input channels to the model. The default value is 3. - - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - This is used for [ResNeXt](@ref)-like models. The default value is 1. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels: `planes * base_width / 64 * cardinality`. - The default value is 64. - - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. - - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: - - + `:default` - a single 7x7 convolution layer with a width of `stem_width` - + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` - + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. - The default value is `:default`. - - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a - convolution layer. The default value is false. - - `reduce_first`: Reduction factor for first convolution output width of residual blocks, - Default is 1 for all architectures except SE-Nets, where it is 2. - - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the - skip connection. The default value is (1, 1) for all architectures except - SE-Nets, where it is (3, 3). - - `avg_down`: Use average pooling for projection skip connection between stages/downsample. - - `activation`: The activation function to use. The default value is `relu`. - - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. - - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `DropPath`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. - - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. 
- -If you are an end-user trying to tweak the ResNet model, note that there is no guarantee that -all combinations of parameters will work. In particular, tweaking `block_kwargs` is not -advised unless you know what you are doing. -""" function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), - avg_down = false, activation = relu, norm_layer = BatchNorm, - drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, - block_kwargs...) - @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" - @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + stem_fn = resnet_stem, stem_args::NamedTuple = (), + downsample_fn = downsample_conv, downsample_args::NamedTuple = (), + drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = 0.0), + block_args::NamedTuple = ()) # Stem - inplanes = stem_type == :deep ? stem_width * 2 : 64 - if stem_type == :deep - stem_channels = (stem_width, stem_width) - if stem_type == :deep_tiered - stem_channels = (3 * (stem_width ÷ 4), stem_width) - end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, - bias = false), - norm_layer(stem_channels[1], activation), - Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, - bias = false), - norm_layer(stem_channels[2], activation), - Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) - else - conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) - end - bn1 = norm_layer(inplanes, activation) - # Stem pooling - if replace_stem_pool - stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, - bias = false), - norm_layer(inplanes, activation)) - else - stempool = MaxPool((3, 3); stride = 2, pad = 1) - end - stem = Chain(conv1, bn1, stempool) + stem = stem_fn(; inchannels, stem_args...) # Feature Blocks channels = [64, 128, 256, 512] - stage_blocks = _make_blocks(block, channels, layers, inplanes; cardinality, base_width, - output_stride, reduce_first, avg_down, - down_kernel_size, activation, norm_layer, - drop_block_rate, drop_path_rate, block_kwargs...) + stage_blocks = _make_blocks(block, channels, layers, inchannels; + output_stride, downsample_fn, downsample_args, + drop_block_rate = drop_rates.drop_block_rate, + drop_path_rate = drop_rates.drop_path_rate, + block_args) # Head (Pooling and Classifier) expansion = expansion_factor(block) num_features = 512 * expansion - classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, + classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten, Dense(num_features, nclasses)) return Chain(Chain(stem, stage_blocks), classifier) end @@ -239,59 +199,6 @@ const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) - -""" - ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) - -Creates a ResNet model. -((reference)[https://arxiv.org/abs/1512.03385]) - -# Arguments: - - - `depth`: The depth of the `ResNet` model. Must be one of `[18, 34, 50, 101, 152]`. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. - - `nclasses`: The number of output classes. The default value is 1000. 
- -Apart from these, the model can also take any of the following optional arguments: - - - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for - example. - - - `layers`: A list of integers representing the number of blocks in each stage. - - `inchannels`: The number of input channels to the model. The default value is 3. - - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - This is used for [ResNeXt](@ref)-like models. The default value is 1. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels: `planes * base_width / 64 * cardinality`. - The default value is 64. - - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. - - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: - - + `:default` - a single 7x7 convolution layer with a width of `stem_width` - + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` - + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. - The default value is `:default`. - - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a - convolution layer. The default value is false. - - `reduce_first`: Reduction factor for first convolution output width of residual blocks, - Default is 1 for all architectures except SE-Nets, where it is 2. - - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the - skip connection. The default value is (1, 1) for all architectures except - SE-Nets, where it is (3, 3). - - `avg_down`: Use average pooling for projection skip connection between stages/downsample. - - `activation`: The activation function to use. The default value is `relu`. - - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. - - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. - - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. - -See also [`resnet`](@ref) for more details. - -!!! warning - - Pretrained models are not supported for all parameter combinations of `ResNet`. -""" struct ResNet layers::Any end @@ -300,78 +207,8 @@ end function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" model = resnet(resnet_config[depth]...; nclasses, kwargs...) - pretrain && loadpretrain!(model, string("resnet", depth)) - return model -end - -""" - ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, nclasses = 1000, - kwargs...) - -Creates a ResNeXt model. -((reference)[https://arxiv.org/abs/1611.05431]) - -# Arguments: - - - `depth`: The depth of the `ResNeXt` model. Must be one of `[50, 101, 152]`. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - of the `ResNeXt` mode. The default value is 4. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels. The default value is 32. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. 
- - `nclasses`: The number of output classes. The default value is 1000. - -Apart from these, the model can also take any of the following optional arguments: - - - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for - example. - - - `layers`: A list of integers representing the number of blocks in each stage. - - `inchannels`: The number of input channels to the model. The default value is 3. - - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - This is used for [ResNeXt](@ref)-like models. The default value is 1. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels: `planes * base_width / 64 * cardinality`. - The default value is 64. - - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. - - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: - - + `:default` - a single 7x7 convolution layer with a width of `stem_width` - + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` - + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. - The default value is `:default`. - - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a - convolution layer. The default value is false. - - `reduce_first`: Reduction factor for first convolution output width of residual blocks, - Default is 1 for all architectures except SE-Nets, where it is 2. - - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the - skip connection. The default value is (1, 1) for all architectures except - SE-Nets, where it is (3, 3). - - `avg_down`: Use average pooling for projection skip connection between stages/downsample. - - `activation`: The activation function to use. The default value is `relu`. - - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. - - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. - - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. - -See also [`resnet`](@ref) for more details. - -!!! warning - - Pretrained models are not currently supported for `ResNeXt`. -""" -struct ResNeXt - layers::Any -end -@functor ResNeXt - -function ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, - nclasses = 1000, - kwargs...) - @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - model = resnet(resnet_config[depth]...; cardinality, base_width, nclasses, kwargs...) 
- pretrain && - loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + if pretrain + loadpretrain!(model, string("resnet", depth)) + end return model end From a1d5ddc7b978b772b789560f90c5d5d54f03bbb6 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 28 Jun 2022 22:58:04 +0530 Subject: [PATCH 08/64] Make `DropBlock` really work --- Project.toml | 3 +- src/Metalhead.jl | 4 +-- src/convnets/resne(x)t.jl | 16 ++++----- src/layers/Layers.jl | 2 ++ src/layers/drop.jl | 75 ++++++++++++++++++++++----------------- 5 files changed, 57 insertions(+), 43 deletions(-) diff --git a/Project.toml b/Project.toml index bd618e534..546717e39 100644 --- a/Project.toml +++ b/Project.toml @@ -5,14 +5,15 @@ version = "0.7.3" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 5463c64de..a9251150f 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -39,7 +39,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -48,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, :ResNet, +for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 50b75fbfd..eadc3d047 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -107,7 +107,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - return Chain(conv1, bn1, stempool) + return inplanes, Chain(conv1, bn1, stempool) end function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), @@ -150,7 +150,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride end # Downsample block; either a (default) convolution-based block or a pooling-based block. downsample = downsample_block(downsample_fn, inplanes, planes, expansion; - downsample_args...) + stride, dilation, first_dilation = dilation, downsample_args...) 
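# Editor's note (descriptive comment, not part of the patch): `downsample_block` only builds a
# real projection for the first block of a stage, and only when the shape changes. For example,
# with `bottleneck` (expansion factor 4), `inplanes = 64`, `planes = 64` and `stride = 1`,
# `64 != 64 * 4`, so a 1x1 `downsample_conv` projecting 64 => 256 channels (plus normalisation)
# is created; every later block in the stage falls back to `identity`.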
# Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks @@ -172,16 +172,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride end function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - stem_fn = resnet_stem, stem_args::NamedTuple = (), - downsample_fn = downsample_conv, downsample_args::NamedTuple = (), + stem_fn = resnet_stem, stem_args::NamedTuple = NamedTuple(), + downsample_fn = downsample_conv, downsample_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - block_args::NamedTuple = ()) + drop_block_rate = 0.5), + block_args::NamedTuple = NamedTuple()) # Stem - stem = stem_fn(; inchannels, stem_args...) + inplanes, stem = stem_fn(; inchannels, stem_args...) # Feature Blocks channels = [64, 128, 256, 512] - stage_blocks = _make_blocks(block, channels, layers, inchannels; + stage_blocks = _make_blocks(block, channels, layers, inplanes; output_stride, downsample_fn, downsample_args, drop_block_rate = drop_rates.drop_block_rate, drop_path_rate = drop_rates.drop_path_rate, diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 6c417c077..1e75b53d6 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,12 +1,14 @@ module Layers using Flux +using CUDA using NNlib using NNlibCUDA using Functors using ChainRulesCore using Statistics using MLUtils +using Random include("../utilities.jl") diff --git a/src/layers/drop.jl b/src/layers/drop.jl index df66a3e7f..fdbdc7db7 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,28 +1,33 @@ -""" - DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) +function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale, active::Bool = true) where {T} + active || return x + H, W, _, _ = size(x) + total_size = H * W + clipped_block_size = min(block_size, min(H, W)) + gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / + ((W - block_size + 1) * (H - block_size + 1)) + block_mask = rand_like(rng, x) .< gamma + block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); + stride = 1, pad = clipped_block_size ÷ 2) + block_mask = 1 .- block_mask + normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) + return x .* block_mask .* normalize_scale +end +dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) +function dropblock(rng, x::CuArray, p; kwargs...) + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays.")) +end -Implements DropBlock, a regularization method for convolutional networks. 
-([reference](https://arxiv.org/pdf/1810.12890.pdf)) -""" -struct DropBlock{F} +struct DropBlock{F, R <: AbstractRNG} drop_block_prob::F block_size::Integer gamma_scale::F + active::Union{Bool, Nothing} + rng::R end -@functor DropBlock - -(m::DropBlock)(x) = dropblock(x, m.drop_block_prob, m.block_size, m.gamma_scale) -function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) - if drop_block_prob == 0.0 - return identity - end - @assert drop_block_prob < 0 || drop_block_prob > 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert gamma_scale < 0 || gamma_scale > 1 - "gamma_scale must be between 0 and 1, got $gamma_scale" - return DropBlock(drop_block_prob, block_size, gamma_scale) -end +@functor DropBlock +trainable(a::DropBlock) = (;) function _dropblock_checks(x::T) where {T} if !(T <: AbstractArray) @@ -34,20 +39,26 @@ function _dropblock_checks(x::T) where {T} end ChainRulesCore.@non_differentiable _dropblock_checks(x) -function dropblock(x::AbstractArray{T, 4}, drop_block_prob, block_size, - gamma_scale) where {T} +function (m::DropBlock)(x) _dropblock_checks(x) - H, W, _, _ = size(x) - total_size = H * W - clipped_block_size = min(block_size, min(H, W)) - gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / - ((W - block_size + 1) * (H - block_size + 1)) - block_mask = rand_like(x) .< gamma - block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); - stride = 1, pad = clipped_block_size ÷ 2) - block_mask = 1 .- block_mask - normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) - return x .* block_mask .* normalize_scale + Flux._isactive(m) || return x + return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) +end + +function Flux.testmode!(m::DropBlock, mode = true) + return (m.active = (isnothing(mode) || mode == :auto) ? 
nothing : !mode; m) +end + +function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, + rng = Flux.rng_from_array()) + if drop_block_prob == 0.0 + return identity + end + @assert 0 ≤ drop_block_prob ≤ 1 + "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0 ≤ gamma_scale ≤ 1 + "gamma_scale must be between 0 and 1, got $gamma_scale" + return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) end """ From 3be1d81bb240afc8e3c70cffd9dccc99c1362f3f Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 29 Jun 2022 08:33:23 +0530 Subject: [PATCH 09/64] Construct the stem outside and pass it into `resnet` `downsample_args` is actually redundant --- src/Metalhead.jl | 1 - src/convnets/resne(x)t.jl | 61 +++++++++++++++++++-------------------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index a9251150f..34610c548 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,7 +7,6 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils -using Random import Functors diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index eadc3d047..4c77260f8 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -1,13 +1,3 @@ -# returns `DropBlock`s for each block of the ResNet -function _drop_blocks(drop_block_prob = 0.0) - return [ - identity, - identity, - DropBlock(drop_block_prob, 5, 0.25), - DropBlock(drop_block_prob, 3, 1.00), - ] -end - function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, first_dilation = nothing, norm_layer = BatchNorm) kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size @@ -80,12 +70,15 @@ expansion_factor(::typeof(bottleneck)) = 4 function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) - @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + @assert stem_type in [:default, :deep, :deep_tiered] + "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem - inplanes = stem_type == :deep ? stem_width * 2 : 64 - if stem_type == :deep - stem_channels = (stem_width, stem_width) - if stem_type == :deep_tiered + deep_stem = stem_type == :deep || stem_type == :deep_tiered + inplanes = deep_stem ? stem_width * 2 : 64 + if deep_stem + if stem_type == :deep + stem_channels = (stem_width, stem_width) + elseif stem_type == :deep_tiered stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, @@ -107,7 +100,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - return inplanes, Chain(conv1, bn1, stempool) + return Chain(conv1, bn1, stempool), inplanes end function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), @@ -128,9 +121,8 @@ end # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. 
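# Editor's sketch, not part of the patch: the hook mentioned above is spelled `expansion_factor`
# in this file, e.g. `expansion_factor(::typeof(bottleneck)) = 4`. A hypothetical custom block
# would be wired in the same way:
#     myblock(inplanes, planes; kwargs...) = basicblock(inplanes, planes; kwargs...)
#     expansion_factor(::typeof(myblock)) = 1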
function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, - downsample_fn = downsample_conv, downsample_args::NamedTuple = (), - drop_block_rate = 0.0, drop_path_rate = 0.0, - block_args::NamedTuple = ()) + downsample_fn = downsample_conv, + drop_rates::NamedTuple, block_args::NamedTuple) @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" expansion = expansion_factor(block_fn) stages = [] @@ -139,7 +131,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride dilation = prev_dilation = 1 for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, - _drop_blocks(drop_block_rate))) + _drop_blocks(drop_rates.drop_block_rate))) # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -148,16 +140,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride else net_stride *= stride end - # Downsample block; either a (default) convolution-based block or a pooling-based block. + # Downsample block; either a (default) convolution-based block or a pooling-based block downsample = downsample_block(downsample_fn, inplanes, planes, expansion; - stride, dilation, first_dilation = dilation, downsample_args...) + stride, dilation, first_dilation = dilation) # Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 # stochastic depth linear decay rule - block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + block_dpr = drop_rates.drop_path_rate * net_block_idx / (sum(block_repeats) - 1) push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, @@ -171,21 +163,26 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride return Chain(stages...) end +# returns `DropBlock`s for each block of the ResNet +function _drop_blocks(drop_block_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_block_prob, 5, 0.25), + DropBlock(drop_block_prob, 3, 1.00), + ] +end + function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - stem_fn = resnet_stem, stem_args::NamedTuple = NamedTuple(), - downsample_fn = downsample_conv, downsample_args::NamedTuple = NamedTuple(), + stem = first(resnet_stem(; inchannels)), inplanes = 64, + downsample_fn = downsample_conv, drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.5), + drop_block_rate = 0.0), block_args::NamedTuple = NamedTuple()) - # Stem - inplanes, stem = stem_fn(; inchannels, stem_args...) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = _make_blocks(block, channels, layers, inplanes; - output_stride, downsample_fn, downsample_args, - drop_block_rate = drop_rates.drop_block_rate, - drop_path_rate = drop_rates.drop_path_rate, - block_args) + output_stride, downsample_fn, drop_rates, block_args) # Head (Pooling and Classifier) expansion = expansion_factor(block) num_features = 512 * expansion From 16cbcd0cc8d2559a5554edc6d465f23aeb67d13e Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 29 Jun 2022 18:40:00 +0530 Subject: [PATCH 10/64] Add ResNeXt back Also add tests. 
A lot of tests --- src/Metalhead.jl | 5 ++-- src/convnets/resne(x)t.jl | 28 ++++++++++++++----- src/layers/Layers.jl | 3 +-- src/layers/drop.jl | 4 +-- src/utilities.jl | 18 ------------- test/convnets.jl | 57 ++++++++++++++++++++++++++++----------- 6 files changed, 69 insertions(+), 46 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 34610c548..172f01d16 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,6 +7,7 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils +using Random import Functors @@ -38,7 +39,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -47,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt, +for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 4c77260f8..a596fc2d1 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -173,7 +173,7 @@ function _drop_blocks(drop_block_prob = 0.0) ] end -function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, +function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, downsample_fn = downsample_conv, drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, @@ -181,10 +181,10 @@ function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = block_args::NamedTuple = NamedTuple()) # Feature Blocks channels = [64, 128, 256, 512] - stage_blocks = _make_blocks(block, channels, layers, inplanes; + stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fn, drop_rates, block_args) # Head (Pooling and Classifier) - expansion = expansion_factor(block) + expansion = expansion_factor(block_fn) num_features = 512 * expansion classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten, Dense(num_features, nclasses)) @@ -201,11 +201,27 @@ struct ResNet end @functor ResNet -function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) - @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - model = resnet(resnet_config[depth]...; nclasses, kwargs...) +function ResNet(depth::Integer; pretrain = false, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + model = resnet(resnet_config[depth]...; nclasses) if pretrain loadpretrain!(model, string("resnet", depth)) end return model end + +struct ResNeXt + layers::Any +end +@functor ResNeXt + +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, block_args = (; cardinality, base_width)) + if pretrain + loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + end + return model +end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 1e75b53d6..efefd91b2 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -2,8 +2,7 @@ module Layers using Flux using CUDA -using NNlib -using NNlibCUDA +using NNlib, NNlibCUDA using Functors using ChainRulesCore using Statistics diff --git a/src/layers/drop.jl b/src/layers/drop.jl index fdbdc7db7..dbb7ddc34 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -13,8 +13,8 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, bl normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) return x .* block_mask .* normalize_scale end -dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) -function dropblock(rng, x::CuArray, p; kwargs...) +dropoutblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) +function dropblock(rng, x::CuArray, p, args...) throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays.")) end diff --git a/src/utilities.jl b/src/utilities.jl index 0c4f46796..930cc621a 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -9,24 +9,6 @@ function _round_channels(channels, divisor, min_value = divisor) return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels end -""" - addrelu(x, y) - -Convenience function for `(x, y) -> @. relu(x + y)`. -Useful as the `connection` argument for [`resnet`](#). -See also [`reluadd`](#). -""" -addrelu(x, y) = @. relu(x + y) - -""" - reluadd(x, y) - -Convenience function for `(x, y) -> @. relu(x) + relu(y)`. -Useful as the `connection` argument for [`resnet`](#). -See also [`addrelu`](#). -""" -reluadd(x, y) = @. relu(x) + relu(y) - """ cat_channels(x, y, zs...) 
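For orientation, a minimal sketch of how the reworked API is exercised at this point in the
series (it mirrors the tests added below and is not part of the patch itself):

using Metalhead
# mid-level entry point: any block function plus the per-stage repeat counts
m = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                     drop_rates = (drop_rate = 0.1, drop_path_rate = 0.1,
                                   drop_block_rate = 0.1))
# exported wrapper that forwards `cardinality` and `base_width` via `block_args`
mx = ResNeXt(50; cardinality = 32, base_width = 4)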
diff --git a/test/convnets.jl b/test/convnets.jl index e547084e0..0238b285a 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -27,18 +27,39 @@ GC.safepoint() GC.gc() @testset "ResNet" begin - @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] - m = ResNet(sz) - @test size(m(x_256)) == (1000, 1) - ## TODO: find a way to port pretrained models to the new ResNet API + # Tests for pretrained ResNets + ## TODO: find a way to port pretrained models to the new ResNet API + # @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] # if (ResNet, sz) in PRETRAINED_MODELS # @test acctest(ResNet(sz, pretrain = true)) # else # @test_throws ArgumentError ResNet(sz, pretrain = true) # end - @test gradtest(m, x_256) - GC.safepoint() - GC.gc() + # end + + @testset "resnet" begin + @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3] + ] + @testset for layers in layer_list + drop_list = [ + (drop_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), + (drop_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), + (drop_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), + ] + @testset for drop_rates in drop_list + m = Metalhead.resnet(block_fn, layers; drop_rates) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + GC.safepoint() + GC.gc() + end + end + end end end @@ -47,16 +68,20 @@ GC.gc() @testset "ResNeXt" begin @testset for depth in [50, 101, 152] - m = ResNeXt(depth) - @test size(m(x_224)) == (1000, 1) - if ResNeXt in PRETRAINED_MODELS - @test acctest(ResNeXt(depth, pretrain = true)) - else - @test_throws ArgumentError ResNeXt(depth, pretrain = true) + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if string("resnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + GC.safepoint() + GC.gc() + end end - @test gradtest(m, x_224) - GC.safepoint() - GC.gc() end end From e5294ec5510cd5fd419cdb33abc60c3901d182ad Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 30 Jun 2022 06:50:07 +0530 Subject: [PATCH 11/64] Enable CI for Windows --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6ba45e03c..29ee17f1a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,7 +22,7 @@ jobs: os: - ubuntu-latest - macOS-latest - # - windows-latest + - windows-latest arch: - x64 steps: From a439bdfa06930404c48cc143dbc0048c5bf6557e Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 29 Jun 2022 19:54:22 +0530 Subject: [PATCH 12/64] Add more general implementation of SE layer Also 1. Tweaks - II : Formatting + some docs 2. 
Groundwork for abstracting out the classifier --- src/convnets/convmixer.jl | 4 +- src/convnets/convnext.jl | 2 +- src/convnets/inception.jl | 51 ++++++++++---------- src/convnets/mobilenet.jl | 12 ++--- src/convnets/resne(x)t.jl | 99 +++++++++++++++++++++++++++++++++------ src/convnets/vgg.jl | 28 +++++------ src/layers/Layers.jl | 8 ++-- src/layers/classifier.jl | 12 +++++ src/layers/conv.jl | 25 ++-------- src/layers/embeddings.jl | 2 +- src/layers/mlp-linear.jl | 21 +++++---- src/layers/pool.jl | 26 ++++++++++ src/layers/selayers.jl | 41 ++++++++++++++++ src/other/mlpmixer.jl | 28 ++++++----- src/vit-based/vit.jl | 20 ++++---- test/convnets.jl | 6 +-- 16 files changed, 260 insertions(+), 125 deletions(-) create mode 100644 src/layers/classifier.jl create mode 100644 src/layers/pool.jl create mode 100644 src/layers/selayers.jl diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index f70473ca5..d36f1a8d5 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -9,7 +9,7 @@ Creates a ConvMixer model. - `planes`: number of planes in the output of each block - `depth`: number of layers - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `kernel_size`: kernel size of the convolutional layers - `patch_size`: size of the patches - `activation`: activation function used after the convolutional layers @@ -45,7 +45,7 @@ Creates a ConvMixer model. # Arguments - `mode`: the mode of the model, either `:base`, `:small` or `:large` - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `activation`: activation function used after the convolutional layers - `nclasses`: number of classes in the output """ diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 3fef58d1d..f3da6dbf3 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -92,7 +92,7 @@ Creates a ConvNeXt model. # Arguments: - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `drop_path_rate`: Stochastic depth rate. - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239) - `nclasses`: number of output classes diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index ba30fa86f..156362cf3 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -281,7 +281,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) + inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -289,10 +289,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. 
""" -function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) +function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -315,13 +315,13 @@ function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) inceptionv4_c(), inceptionv4_c(), inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(1536, nclasses)) return Chain(body, head) end """ - Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) + Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -330,7 +330,7 @@ Creates an Inceptionv4 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning @@ -341,8 +341,9 @@ struct Inceptionv4 layers::Any end -function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = inceptionv4(; inchannels, drop_rate, nclasses) +function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionv4(; inchannels, dropout_rate, nclasses) if pretrain loadpretrain!(layers, "Inceptionv4") end @@ -424,7 +425,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels = 3, drop_rate =0.0, nclasses = 1000) + inceptionresnetv2(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -432,10 +433,10 @@ Creates an InceptionResNetv2 model. # Arguments - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) +function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -451,13 +452,13 @@ function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), conv_bn((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(1536, nclasses)) return Chain(body, head) end """ - InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -466,7 +467,7 @@ Creates an InceptionResNetv2 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! 
warning @@ -477,9 +478,9 @@ struct InceptionResNetv2 layers::Any end -function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, +function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) + layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) if pretrain loadpretrain!(layers, "InceptionResNetv2") end @@ -504,7 +505,7 @@ Create an Xception block. # Arguments - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `outchannels`: number of output channels. - `nrepeats`: number of repeats of depthwise separable convolution layers. - `stride`: stride by which to downsample the input. @@ -541,7 +542,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end """ - xception(; inchannels = 3, drop_rate =0.0, nclasses = 1000) + xception(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -549,10 +550,10 @@ Creates an Xception model. # Arguments - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) +function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)..., conv_bn((3, 3), 32, 64; bias = false)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), @@ -562,7 +563,7 @@ function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) xception_block(728, 1024, 2; stride = 2, grow_at_start = false), depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(2048, nclasses)) return Chain(body, head) end @@ -572,7 +573,7 @@ struct Xception end """ - Xception(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) + Xception(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -581,15 +582,15 @@ Creates an Xception model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning `Xception` does not currently support pretrained weights. """ -function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = xception(; inchannels, drop_rate, nclasses) +function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, dropout_rate, nclasses) if pretrain loadpretrain!(layers, "xception") end diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 25067a631..15dc037e8 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -21,7 +21,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
+ `s`: The stride of the convolutional kernel + `r`: The number of time this configuration block is repeated - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `fcsize`: The intermediate fully-connected size between the convolution and final layers - `nclasses`: The number of output classes """ @@ -77,7 +77,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet - `nclasses`: The number of output classes @@ -123,7 +123,7 @@ Create a MobileNetv2 model. + `n`: The number of times a block is repeated + `s`: The stride of the convolutional kernel + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes """ @@ -181,7 +181,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet - `nclasses`: The number of output classes @@ -226,7 +226,7 @@ Create a MobileNetv3 model. + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + `s::Integer` - The stride of the convolutional kernel + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes """ @@ -312,7 +312,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `pretrain`: whether to load the pre-trained weights for ImageNet - `nclasses`: the number of output classes diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index a596fc2d1..74140d625 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -68,6 +68,34 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina end expansion_factor(::typeof(bottleneck)) = 4 +""" + resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, + norm_layer = BatchNorm, activation = relu) + +Builds a stem to be used in a ResNet model. See the `stem` argument of `resnet` for details +on how to use this function. + +# Arguments: + + - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. 
+ + + `:default`: Builds a stem based on the default ResNet stem, which consists of a single + 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 + max pooling layer with stride 2. + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2 for one) in using a + deeper stem with 3 successive 3x3 convolutions having normalisation layers + after each one. This is followed by a 3x3 max pooling layer with stride 2. + + `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second + convolution. This is an experimental variant from the `timm` library + in Python that shows peformance improvements over the `:deep` stem + in some cases. + + - `inchannels`: The number of channels in the input. + - `replace_stem_pool`: Whether to replace the default 3x3 max pooling layer with a + 3x3 convolution with stride 2 and a normalisation layer. + - `norm_layer`: The normalisation layer used in the stem. + - `activation`: The activation function used in the stem. +""" function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) @assert stem_type in [:default, :deep, :deep_tiered] @@ -75,13 +103,14 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = # Main stem deep_stem = stem_type == :deep || stem_type == :deep_tiered inplanes = deep_stem ? stem_width * 2 : 64 + # Deep stem that uses three successive 3x3 convolutions instead of a single 7x7 convolution if deep_stem if stem_type == :deep stem_channels = (stem_width, stem_width) elseif stem_type == :deep_tiered stem_channels = (3 * (stem_width ÷ 4), stem_width) end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[1]; stride = 2, pad = 1, bias = false), norm_layer(stem_channels[1], activation), Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, @@ -129,9 +158,10 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 + dbr = haskey(drop_rates, :drop_block_rate) ? drop_rates.drop_block_rate : 0 for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, - _drop_blocks(drop_rates.drop_block_rate))) + _drop_blocks(dbr))) # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -149,7 +179,8 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 # stochastic depth linear decay rule - block_dpr = drop_rates.drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + dpr = haskey(drop_rates, :drop_path_rate) ? drop_rates.drop_path_rate : 0 + block_dpr = dpr * net_block_idx / (sum(block_repeats) - 1) push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, @@ -163,22 +194,20 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride return Chain(stages...) 
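# Editor's note (descriptive comment, not part of the patch): `block_dpr` above implements a
# linear ramp for stochastic depth. With `drop_path_rate = 0.1` and `block_repeats = [3, 4, 6, 3]`
# (16 blocks in total), the first block gets roughly 0.1 * 1 / 15 ≈ 0.007 and the last
# roughly 0.1 * 16 / 15 ≈ 0.107, so deeper blocks are dropped more aggressively.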
end -# returns `DropBlock`s for each block of the ResNet +# returns `DropBlock`s for each stage of the ResNet function _drop_blocks(drop_block_prob = 0.0) return [ - identity, - identity, - DropBlock(drop_block_prob, 5, 0.25), - DropBlock(drop_block_prob, 3, 1.00), + identity, identity, + DropBlock(drop_block_prob, 5, 0.25), DropBlock(drop_block_prob, 3, 1.00), ] end function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, - downsample_fn = downsample_conv, - drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, + downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), + drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), - block_args::NamedTuple = NamedTuple()) + classifier_args::NamedTuple = NamedTuple()) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; @@ -186,11 +215,13 @@ function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride # Head (Pooling and Classifier) expansion = expansion_factor(block_fn) num_features = 512 * expansion - classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten, - Dense(num_features, nclasses)) + global_pool, fc = create_classifier(num_features, nclasses; classifier_args...) + dr = haskey(drop_rates, :dropout_rate) ? drop_rates.dropout_rate : 0 + classifier = Chain(global_pool, Dropout(dr), fc) return Chain(Chain(stem, stage_blocks), classifier) end +# block-layer configurations for ResNet and ResNeXt models const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), 34 => (basicblock, [3, 4, 6, 3]), 50 => (bottleneck, [3, 4, 6, 3]), @@ -201,6 +232,23 @@ struct ResNet end @functor ResNet +""" + ResNet(depth::Integer; pretrain = false, nclasses = 1000) + +Creates a ResNet model with the specified depth. + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +!!! warning + + `ResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" function ResNet(depth::Integer; pretrain = false, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" @@ -216,10 +264,31 @@ struct ResNeXt end @functor ResNeXt -function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) +""" + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + +Creates a ResNeXt model with the specified depth, cardinality, and base width. + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `nclasses`: the number of output classes + +!!! warning + + `ResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. 
Must be one of [50, 101, 152]" - model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, block_args = (; cardinality, base_width)) + model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, + block_args = (; cardinality, base_width)) if pretrain loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 15560de7c..957a0a483 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -52,7 +52,7 @@ function vgg_convolutional_layers(config, batchnorm, inchannels) end """ - vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) + vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -63,19 +63,19 @@ Create VGG classifier (fully connected) layers the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `drop_rate`: the dropout level between each fully connected layer + - `dropout_rate`: the dropout level between each fully connected layer """ -function vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) +function vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) return Chain(MLUtils.flatten, Dense(Int(prod(imsize)), fcsize, relu), - Dropout(drop_rate), + Dropout(dropout_rate), Dense(fcsize, fcsize, relu), - Dropout(drop_rate), + Dropout(dropout_rate), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -90,12 +90,12 @@ Create a VGG model - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `drop_rate`: dropout level between fully connected layers + - `dropout_rate`: dropout level between fully connected layers """ -function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) +function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) + class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) return Chain(Chain(conv), class) end @@ -114,7 +114,7 @@ struct VGG end """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -126,11 +126,11 @@ Construct a VGG model with the specified input image size. 
Typically, the image - `nclasses`::Integer : number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `drop_rate`: dropout level between fully connected layers + - `dropout_rate`: dropout level between fully connected layers """ function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, - drop_rate) - layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, drop_rate) + dropout_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, dropout_rate) return VGG(layers) end @@ -159,7 +159,7 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses batchnorm = batchnorm, nclasses = nclasses, fcsize = 4096, - drop_rate = 0.5) + dropout_rate = 0.5) if pretrain && !batchnorm loadpretrain!(model, string("vgg", depth)) elseif pretrain diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index efefd91b2..f58f40172 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -17,14 +17,16 @@ include("mlp-linear.jl") include("normalise.jl") include("conv.jl") include("drop.jl") +include("selayers.jl") +include("classifier.jl") export MHAttention, PatchEmbedding, ViPosEmbedding, ClassTokens, mlp_block, gated_mlp_block, - LayerScale, DropPath, + LayerScale, DropPath, DropBlock, ChannelLayerNorm, prenorm, skip_identity, skip_projection, conv_bn, depthwise_sep_conv_bn, - invertedresidual, squeeze_excite, - DropBlock + squeeze_excite, effective_squeeze_excite, + invertedresidual, create_classifier end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl new file mode 100644 index 000000000..04be6ef86 --- /dev/null +++ b/src/layers/classifier.jl @@ -0,0 +1,12 @@ +function create_classifier(inplanes, nclasses; pool_type = :avg, use_conv = false) + flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling + if pool_type == :identity + @assert use_conv + "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" + flatten_in_pool = false # disable flattening if pooling is pass-through (no pooling) + end + global_pool = SelectAdaptivePool(; pool_type, flatten = flatten_in_pool) + fc = use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : + Dense(inplanes => nclasses; bias = true) + return global_pool, fc +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 6363946d0..e56967aef 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -144,27 +144,6 @@ function skip_identity(inplanes, outplanes) end skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes) -""" - squeeze_excite(channels, reduction = 4) - -Squeeze and excitation layer used by MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `channels`: the number of input/output feature maps - - `reduction = 4`: the reduction factor for the number of hidden feature maps - (must be ≥ 1) -""" -function squeeze_excite(channels, reduction = 4) - @assert (reduction>=1) "`reduction` must be >= 1" - return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - conv_bn((1, 1), channels, channels ÷ reduction, relu; - bias = false)..., - conv_bn((1, 1), channels ÷ reduction, channels, hardσ)...), - .*) -end - """ invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; stride, reduction = nothing) @@ -190,7 +169,9 @@ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, pad = @. 
(kernel_size - 1) ÷ 2 conv1 = (inplanes == hidden_planes) ? identity : Chain(conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false)) - selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes, reduction) + selayer = isnothing(reduction) ? identity : + squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, + norm_layer = BatchNorm) invres = Chain(conv1, conv_bn(kernel_size, hidden_planes, hidden_planes, activation; bias = false, stride, pad = pad, groups = hidden_planes)..., diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index ad079db9d..66f25d1c0 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -11,7 +11,7 @@ patches. # Arguments: - `imsize`: the size of the input image - - `inchannels`: the number of channels in the input. The default value is 3. + - `inchannels`: the number of channels in the input. - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding - `norm_layer`: the normalization layer - by default the identity function but otherwise takes a diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp-linear.jl index 550c2ad22..8cca1e266 100644 --- a/src/layers/mlp-linear.jl +++ b/src/layers/mlp-linear.jl @@ -15,7 +15,7 @@ end """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - drop_rate =0., activation = gelu) + dropout_rate =0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -24,18 +24,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models. - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `drop_rate`: Dropout rate. + - `dropout_rate`: Dropout rate. - `activation`: Activation function to use. """ function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - drop_rate = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(drop_rate), - Dense(hidden_planes, outplanes), Dropout(drop_rate)) + dropout_rate = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout_rate), + Dense(hidden_planes, outplanes), Dropout(dropout_rate)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) + outplanes::Integer = inplanes; dropout_rate = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -46,16 +46,17 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `drop_rate`: Dropout rate. + - `dropout_rate`: Dropout rate. - `activation`: Activation function to use. 
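A quick usage sketch of the renamed keyword (editorial addition, not part of the diff; it assumes the `Layers` submodule is in scope and that `mlp_block` keeps the signature shown above):

    using Flux, Metalhead
    block = Metalhead.Layers.mlp_block(192, 768; dropout_rate = 0.1)  # Dense -> Dropout -> Dense -> Dropout
    x = rand(Float32, 192, 196, 2)                                    # (planes, npatches, batch)
    size(block(x))                                                    # (192, 196, 2)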
""" function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) + outplanes::Integer = inplanes; dropout_rate = 0.0, + activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(drop_rate), + Dropout(dropout_rate), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(drop_rate)) + Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/layers/pool.jl b/src/layers/pool.jl new file mode 100644 index 000000000..aa5755240 --- /dev/null +++ b/src/layers/pool.jl @@ -0,0 +1,26 @@ +function AdaptiveMeanMaxPool(output_size = (1, 1)) + return 0.5 * Parallel(.+, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) +end + +function AdaptiveCatMeanMaxPool(output_size = (1, 1)) + return Parallel(cat_channels, AdaptiveAvgMaxPool(output_size), + AdaptiveMaxPool(output_size)) +end + +function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) + if pool_type == :mean + pool = AdaptiveAvgPool(output_size) + elseif pool_type == :max + pool = AdaptiveMaxPool(output_size) + elseif pool_type == :meanmax + pool = AdaptiveAvgMaxPool(output_size) + elseif pool_type == :catmeanmax + pool = AdaptiveCatAvgMaxPool(output_size) + elseif pool_type = :identity + pool = identity + else + throw(AssertionError("Invalid pool type: $pool_type")) + end + flatten_fn = flatten ? MLUtils.flatten : identity + return Chain(pool, flatten_fn) +end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl new file mode 100644 index 000000000..acd7e9809 --- /dev/null +++ b/src/layers/selayers.jl @@ -0,0 +1,41 @@ +""" + squeeze_excite(inplanes, reduction = 16; rd_divisor = 8, + activation = relu, gate_activation = sigmoid, norm_layer = identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + +Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. + +# Arguments + + - `inplanes`: The number of input feature maps + - `reduction`: The reduction factor for the number of hidden feature maps + - `rd_divisor`: The divisor for the number of hidden feature maps. + - `activation`: The activation function for the first convolution layer + - `gate_activation`: The activation function for the gate layer + - `norm_layer`: The normalization layer to be used after the convolution layers + - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer + Must be ≥ 1 or `nothing` for no squeeze and excite layer. +""" +function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, + activation = relu, gate_activation = sigmoid, norm_layer = identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes => rd_planes), + norm_layer, + activation, + Conv((1, 1), rd_planes => inplanes), + norm_layer, + gate_activation), .*) +end + +""" + effective_squeeze_excite(inplanes, gate_layer = sigmoid) + +Effective squeeze-and-excitation layer. +(reference: [CenterMask : Real-Time Anchor-Free Instance Segmentation](https://arxiv.org/abs/1911.06667)) +""" +function effective_squeeze_excite(inplanes; gate_activation = sigmoid, kwargs...) 
+ return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes, inplanes), + gate_activation), .*) +end diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index ed4c47af3..5083b228e 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - drop_rate =0., drop_path_rate = 0., activation = gelu) + dropout_rate =0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -12,22 +12,22 @@ Creates a feedforward block for the MLPMixer architecture. - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP and/or the channel mixing MLP as a ratio to the number of planes in the block. - `mlp_layer`: the MLP layer to use in the block - - `drop_rate`: the dropout rate to use in the MLP blocks + - `dropout_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] return Chain(SkipConnection(Chain(LayerNorm(planes), swapdims((2, 1, 3)), mlp_layer(npatches, tokenplanes; activation, - drop_rate), + dropout_rate), swapdims((2, 1, 3)), DropPath(drop_path_rate)), +), SkipConnection(Chain(LayerNorm(planes), mlp_layer(planes, channelplanes; activation, - drop_rate), + dropout_rate), DropPath(drop_path_rate)), +)) end @@ -115,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] """ - resmixerblock(planes, npatches; drop_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; dropout_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, activation = gelu, λ = 1e-4) Creates a block for the ResMixer architecture. @@ -128,13 +128,14 @@ Creates a block for the ResMixer architecture. 
- `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `mlp_layer`: the MLP block to use - - `drop_rate`: the dropout rate to use in the MLP blocks + - `dropout_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks - `λ`: initialisation constant for the LayerScale """ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, + λ = 1e-4) return Chain(SkipConnection(Chain(Flux.Scale(planes), swapdims((2, 1, 3)), Dense(npatches, npatches), @@ -142,7 +143,8 @@ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, LayerScale(planes, λ), DropPath(drop_path_rate)), +), SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); drop_rate, + mlp_layer(planes, Int(mlp_ratio * planes); + dropout_rate, activation), LayerScale(planes, λ), DropPath(drop_path_rate)), +)) @@ -232,7 +234,7 @@ end """ spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, drop_rate = 0.0, drop_path_rate = 0.0, + norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) Creates a feedforward block based on the gMLP model architecture described in the paper. @@ -245,19 +247,19 @@ Creates a feedforward block based on the gMLP model architecture described in th - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `norm_layer`: the normalisation layer to use - - `drop_rate`: the dropout rate to use in the MLP blocks + - `dropout_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, drop_rate = 0.0, + mlp_layer = gated_mlp_block, dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) channelplanes = Int(mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), mlp_layer(sgu, planes, channelplanes; activation, - drop_rate), + dropout_rate), DropPath(drop_path_rate)), +) end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 686ddc4d5..a06ce6886 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate =0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate =0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). @@ -10,23 +10,23 @@ Transformer as used in the base ViT architecture. 
- `depth`: number of attention blocks - `nheads`: number of attention heads - `mlp_ratio`: ratio of MLP layers to the number of input channels - - `drop_rate`: dropout rate + - `dropout_rate`: dropout rate """ -function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate = 0.0) +function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.0) layers = [Chain(SkipConnection(prenorm(planes, MHAttention(planes, nheads; - attn_drop_rate = drop_rate, - proj_drop_rate = drop_rate)), +), + attn_drop_rate = dropout_rate, + proj_drop_rate = dropout_rate)), +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); - drop_rate)), +)) + dropout_rate)), +)) for _ in 1:depth] return Chain(layers) end """ vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, emb_drop_rate = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. @@ -41,13 +41,13 @@ Creates a Vision Transformer (ViT) model. - `depth`: number of blocks in the transformer - `nheads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer - - `drop_rate`: dropout rate + - `dropout_rate`: dropout rate - `emb_dropout`: dropout rate for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, emb_drop_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)" @@ -57,7 +57,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = ViPosEmbedding(embedplanes, npatches + 1), Dropout(emb_drop_rate), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, - drop_rate), + dropout_rate), (pool == :class) ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end diff --git a/test/convnets.jl b/test/convnets.jl index 0238b285a..887abb410 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -47,9 +47,9 @@ GC.gc() ] @testset for layers in layer_list drop_list = [ - (drop_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), - (drop_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), - (drop_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), + (dropout_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), + (dropout_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), + (dropout_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), ] @testset for drop_rates in drop_list m = Metalhead.resnet(block_fn, layers; drop_rates) From 441ade8e876522480620c210351bbfc8c12c2e3f Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 1 Jul 2022 12:47:22 +0530 Subject: [PATCH 13/64] Tweaks III + Some more docs 1. Reorganise layer imports for easy access 2. 
Get pooling to work --- src/convnets/resne(x)t.jl | 17 ++++---- src/layers/Layers.jl | 30 +++++++++---- src/layers/classifier.jl | 2 +- src/layers/drop.jl | 88 ++++++++++++++++++++++++++++++++++----- src/layers/pool.jl | 8 ++-- src/layers/selayers.jl | 18 +++++--- 6 files changed, 125 insertions(+), 38 deletions(-) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 74140d625..424c0dd55 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -158,10 +158,14 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 - dbr = haskey(drop_rates, :drop_block_rate) ? drop_rates.drop_block_rate : 0 + # Stochastic depth linear decay rule (DropPath) + dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), + sum(block_repeats)) for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, - _drop_blocks(dbr))) + _drop_blocks(get(drop_rates, + :drop_block_rate, + 0)))) # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -178,13 +182,11 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride for block_idx in 1:num_blocks downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 - # stochastic depth linear decay rule - dpr = haskey(drop_rates, :drop_path_rate) ? drop_rates.drop_path_rate : 0 - block_dpr = dpr * net_block_idx / (sum(block_repeats) - 1) push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, - drop_path = DropPath(block_dpr), drop_block, block_args...)) + drop_path = DropPath(dp_rates[block_idx]), drop_block, + block_args...)) prev_dilation = dilation inplanes = planes * expansion net_block_idx += 1 @@ -216,8 +218,7 @@ function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride expansion = expansion_factor(block_fn) num_features = 512 * expansion global_pool, fc = create_classifier(num_features, nclasses; classifier_args...) - dr = haskey(drop_rates, :dropout_rate) ? 
drop_rates.dropout_rate : 0 - classifier = Chain(global_pool, Dropout(dr), fc) + classifier = Chain(global_pool, Dropout(get(drop_rates, :dropout_rate, 0)), fc) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index f58f40172..41a98843e 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,6 +1,7 @@ module Layers using Flux +using Flux: rng_from_array using CUDA using NNlib, NNlibCUDA using Functors @@ -12,21 +13,32 @@ using Random include("../utilities.jl") include("attention.jl") +export MHAttention + include("embeddings.jl") +export PatchEmbedding, ViPosEmbedding, ClassTokens + include("mlp-linear.jl") +export mlp_block, gated_mlp_block, LayerScale + include("normalise.jl") +export prenorm, ChannelLayerNorm + include("conv.jl") +export conv_bn, depthwise_sep_conv_bn, invertedresidual +skip_identity, skip_projection + include("drop.jl") +export DropPath, DropBlock + include("selayers.jl") +export squeeze_excite, effective_squeeze_excite + include("classifier.jl") +export create_classifier + +include("pool.jl") +export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool +SelectAdaptivePool -export MHAttention, - PatchEmbedding, ViPosEmbedding, ClassTokens, - mlp_block, gated_mlp_block, - LayerScale, DropPath, DropBlock, - ChannelLayerNorm, prenorm, - skip_identity, skip_projection, - conv_bn, depthwise_sep_conv_bn, - squeeze_excite, effective_squeeze_excite, - invertedresidual, create_classifier end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl index 04be6ef86..e2a1fe75c 100644 --- a/src/layers/classifier.jl +++ b/src/layers/classifier.jl @@ -1,4 +1,4 @@ -function create_classifier(inplanes, nclasses; pool_type = :avg, use_conv = false) +function create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling if pool_type == :identity @assert use_conv diff --git a/src/layers/drop.jl b/src/layers/drop.jl index dbb7ddc34..c89bd55fe 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,6 +1,27 @@ +""" + dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale, active::Bool = true) + +The dropblock function. If `active` is `true`, for each input, it zeroes out continguous +regions of size `block_size` in the input. Otherwise, it simply returns the input `x`. + +# Arguments + + - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only + supported on the CPU. + - `x`: input array + - `drop_block_prob`: probability of dropping a block + - `block_size`: size of the block to drop + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, + refer to [the paper](https://arxiv.org/abs/1810.12890). + +If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. +""" function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale, active::Bool = true) where {T} - active || return x + if !active + return x + end H, W, _, _ = size(x) total_size = H * W clipped_block_size = min(block_size, min(H, W)) @@ -13,12 +34,14 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, bl normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) return x .* block_mask .* normalize_scale end -dropoutblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) + +dropblock(x, p, args...) 
= dropblock(rng_from_array(x), x, p, args...)
+dropblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...)
 function dropblock(rng, x::CuArray, p, args...)
-    throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays."))
+    throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports CUDA.RNG for CuArrays."))
 end
 
-struct DropBlock{F, R <: AbstractRNG}
+mutable struct DropBlock{F, R <: AbstractRNG}
     drop_block_prob::F
     block_size::Integer
     gamma_scale::F
@@ -41,16 +64,36 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x)
 
 function (m::DropBlock)(x)
     _dropblock_checks(x)
-    Flux._isactive(m) || return x
-    return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale)
+    if Flux._isactive(m)
+        return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale)
+    else
+        return x
+    end
 end
 
 function Flux.testmode!(m::DropBlock, mode = true)
     return (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 end
 
+"""
+    DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0,
+              rng = rng_from_array())
+
+The `DropBlock` layer. While training, it zeroes out contiguous regions of
+size `block_size` in the input. During inference, it simply returns the input `x`.
+([reference](https://arxiv.org/abs/1810.12890))
+
+# Arguments
+
+  - `drop_block_prob`: probability of dropping a block
+  - `block_size`: size of the block to drop
+  - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations,
+    refer to [the paper](https://arxiv.org/abs/1810.12890).
+  - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only
+    supported on the CPU.
+"""
 function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0,
-                   rng = Flux.rng_from_array())
+                   rng = rng_from_array())
     if drop_block_prob == 0.0
         return identity
     end
@@ -61,15 +104,40 @@ function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0,
     return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng)
 end
 
+function Base.show(io::IO, d::DropBlock)
+    print(io, "DropBlock(", d.drop_block_prob)
+    print(io, ", block_size = $(repr(d.block_size))")
+    print(io, ", gamma_scale = $(repr(d.gamma_scale))")
+    return print(io, ")")
+end
+
 """
-    DropPath(p)
+    DropPath(p; [rng = rng_from_array(x)])
 
-Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0 and
+Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and
 `identity` otherwise.
 ([reference](https://arxiv.org/abs/1603.09382))
 
+This layer can be used to drop certain blocks in a residual structure and allow them to
+propagate completely through the skip connection. It can be used in two ways: either with
+all blocks having the same survival probability or with a linear scaling rule across the
+blocks. This is performed only at training time. At test time, the `DropPath` layer is
+equivalent to `identity`.
+
+!!! warning
+
+    In the case of the linear scaling rule, the calculations of survival probabilities for each
+    block may lead to a survival probability > 1 for a given block. This will lead to
+    `DropPath` returning `identity`, which may not be desirable. This usually happens with
+    a low number of blocks and a high base survival probability, so it is recommended to
+    use a fixed base survival probability across blocks. If this is not possible, then
+    a lower base survival probability is recommended.
+
 # Arguments
 
   - `p`: rate of Stochastic Depth. 
+ - `rng`: can be used to pass in a custom RNG instead of the default. See `Flux.Dropout` + for more information on the behaviour of this argument. Custom RNGs are only supported + on the CPU. """ -DropPath(p) = p > 0 ? Dropout(p; dims = 4) : identity +DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity diff --git a/src/layers/pool.jl b/src/layers/pool.jl index aa5755240..4ffe298e3 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -9,14 +9,14 @@ end function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) if pool_type == :mean - pool = AdaptiveAvgPool(output_size) + pool = AdaptiveMeanPool(output_size) elseif pool_type == :max pool = AdaptiveMaxPool(output_size) elseif pool_type == :meanmax - pool = AdaptiveAvgMaxPool(output_size) + pool = AdaptiveMeanMaxPool(output_size) elseif pool_type == :catmeanmax - pool = AdaptiveCatAvgMaxPool(output_size) - elseif pool_type = :identity + pool = AdaptiveCatMeanMaxPool(output_size) + elseif pool_type == :identity pool = identity else throw(AssertionError("Invalid pool type: $pool_type")) diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index acd7e9809..7f1a76d59 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -1,5 +1,5 @@ """ - squeeze_excite(inplanes, reduction = 16; rd_divisor = 8, + squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, activation = relu, gate_activation = sigmoid, norm_layer = identity, rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) @@ -17,22 +17,28 @@ Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, - activation = relu, gate_activation = sigmoid, norm_layer = identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + activation = relu, gate_activation = sigmoid, + norm_layer = planes -> identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), Conv((1, 1), inplanes => rd_planes), - norm_layer, + norm_layer(rd_planes), activation, Conv((1, 1), rd_planes => inplanes), - norm_layer, + norm_layer(inplanes), gate_activation), .*) end """ - effective_squeeze_excite(inplanes, gate_layer = sigmoid) + effective_squeeze_excite(inplanes, gate_activation = sigmoid) Effective squeeze-and-excitation layer. (reference: [CenterMask : Real-Time Anchor-Free Instance Segmentation](https://arxiv.org/abs/1911.06667)) + +# Arguments + + - `inplanes`: The number of input feature maps + - `gate_activation`: The activation function for the gate layer """ function effective_squeeze_excite(inplanes; gate_activation = sigmoid, kwargs...) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), From 5d059f5be6ddab9de195f06ca21ce625599115eb Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 3 Jul 2022 10:34:41 +0530 Subject: [PATCH 14/64] Fix `DropBlock` on the GPU --- src/convnets/resne(x)t.jl | 2 +- src/layers/drop.jl | 43 ++++++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 424c0dd55..03ed646c0 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -288,7 +288,7 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. 
Must be one of [50, 101, 152]" - model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, + model = resnet(resnet_config[depth]...; nclasses, block_args = (; cardinality, base_width)) if pretrain loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) diff --git a/src/layers/drop.jl b/src/layers/drop.jl index c89bd55fe..dc6cb3c54 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,3 +1,11 @@ +# Generates the mask to be used for `DropBlock` +@inline function _dropblock_mask(rng, x, gamma, clipped_block_size) + block_mask = Flux.f32(rand_like(rng, x) .< gamma) + return 1 .- maxpool(block_mask, (clipped_block_size, clipped_block_size); + stride = 1, pad = clipped_block_size ÷ 2) +end +ChainRulesCore.@non_differentiable _dropblock_mask(rng, x, gamma, clipped_block_size) + """ dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale, active::Bool = true) @@ -18,28 +26,25 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. """ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, - gamma_scale, active::Bool = true) where {T} - if !active - return x - end + gamma_scale) where {T} H, W, _, _ = size(x) total_size = H * W clipped_block_size = min(block_size, min(H, W)) gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / ((W - block_size + 1) * (H - block_size + 1)) - block_mask = rand_like(rng, x) .< gamma - block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); - stride = 1, pad = clipped_block_size ÷ 2) - block_mask = 1 .- block_mask + block_mask = dropblock_mask(rng, x, gamma, clipped_block_size) normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) return x .* block_mask .* normalize_scale end -dropblock(x, p, args...) = dropblock(rng_from_array(x), x, p, args...) -dropblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) -function dropblock(rng, x::CuArray, p, args...) +## bs is `clipped_block_size` +# Dispatch for GPU +dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) +function dropblock_mask(rng, x::CuArray, gamma, bs) throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). 
dropblock only supports CUDA.RNG for CuArrays.")) end +# Dispatch for CPU +dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) mutable struct DropBlock{F, R <: AbstractRNG} drop_block_prob::F @@ -52,7 +57,11 @@ end @functor DropBlock trainable(a::DropBlock) = (;) -function _dropblock_checks(x::T) where {T} +function _dropblock_checks(x::T, drop_block_prob, gamma_scale) where {T} + @assert 0 ≤ drop_block_prob ≤ 1 + "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0 ≤ gamma_scale ≤ 1 + "gamma_scale must be between 0 and 1, got $gamma_scale" if !(T <: AbstractArray) throw(ArgumentError("x must be an `AbstractArray`")) end @@ -60,10 +69,10 @@ function _dropblock_checks(x::T) where {T} throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) end end -ChainRulesCore.@non_differentiable _dropblock_checks(x) +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_scale) function (m::DropBlock)(x) - _dropblock_checks(x) + _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) if Flux._isactive(m) return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) else @@ -87,7 +96,7 @@ size `block_size` in the input. During inference, it simply returns the input `x - `drop_block_prob`: probability of dropping a block - `block_size`: size of the block to drop - - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. @@ -97,10 +106,6 @@ function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, if drop_block_prob == 0.0 return identity end - @assert 0 ≤ drop_block_prob ≤ 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert 0 ≤ gamma_scale ≤ 1 - "gamma_scale must be between 0 and 1, got $gamma_scale" return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) end From 226e96a5d7d663ed310ef3cdc629a3b08e22e298 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 3 Jul 2022 11:57:23 +0530 Subject: [PATCH 15/64] Add `SEResNet` and `SEResNeXt` So much GC, might as well have a function for it --- src/Metalhead.jl | 5 +- src/convnets/{resne(x)t.jl => resnets.jl} | 96 ++++++++++++++++--- test/convnets.jl | 107 ++++++++++------------ test/other.jl | 9 +- test/runtests.jl | 5 + test/vit-based.jl | 3 +- 6 files changed, 145 insertions(+), 80 deletions(-) rename src/convnets/{resne(x)t.jl => resnets.jl} (80%) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 172f01d16..a4dd73785 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -22,7 +22,7 @@ include("convnets/alexnet.jl") include("convnets/vgg.jl") include("convnets/inception.jl") include("convnets/googlenet.jl") -include("convnets/resne(x)t.jl") +include("convnets/resnets.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") @@ -43,6 +43,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, + SEResNet, SEResNeXt, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt @@ -50,7 +51,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, 
VGG19, # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, - :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, + :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :SEResNet, :SEResNeXt, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resnets.jl similarity index 80% rename from src/convnets/resne(x)t.jl rename to src/convnets/resnets.jl index 03ed646c0..73f070617 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resnets.jl @@ -26,7 +26,8 @@ end function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) @assert cardinality==1 "`basicblock` only supports cardinality of 1" @assert base_width==64 "`basicblock` does not support changing base width" @@ -40,8 +41,10 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = dilation, dilation = dilation, bias = false), norm_layer(outplanes)) + attn_layer = attn_fn(outplanes; attn_args...) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, + drop_path)), activation) end expansion_factor(::typeof(basicblock)) = 1 @@ -49,7 +52,8 @@ expansion_factor(::typeof(basicblock)) = 1 function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first @@ -61,9 +65,10 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina dilation = first_dilation, groups = cardinality, bias = false), norm_layer(width)) conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) + attn_layer = attn_fn(outplanes; attn_args...) return Chain(Parallel(+, downsample, Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, - drop_path)), + attn_layer, drop_path)), activation) end expansion_factor(::typeof(bottleneck)) = 4 @@ -233,10 +238,13 @@ struct ResNet end @functor ResNet +(m::ResNet)(x) = m.layers(x) + """ ResNet(depth::Integer; pretrain = false, nclasses = 1000) Creates a ResNet model with the specified depth. +((reference)[https://arxiv.org/abs/1512.03385]) # Arguments @@ -253,11 +261,11 @@ Advanced users who want more configuration options will be better served by usin function ResNet(depth::Integer; pretrain = false, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. 
Must be one of [18, 34, 50, 101, 152]"
-    model = resnet(resnet_config[depth]...; nclasses)
+    layers = resnet(resnet_config[depth]...; nclasses)
     if pretrain
-        loadpretrain!(model, string("resnet", depth))
+        loadpretrain!(layers, string("resnet", depth))
     end
-    return model
+    return ResNet(layers)
 end
 
 struct ResNeXt
@@ -265,10 +273,13 @@ struct ResNeXt
 end
 @functor ResNeXt
 
+(m::ResNeXt)(x) = m.layers(x)
+
 """
     ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000)
 
 Creates a ResNeXt model with the specified depth, cardinality, and base width.
+([reference](https://arxiv.org/abs/1611.05431))
 
 # Arguments
 
@@ -288,10 +299,73 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width
                  nclasses = 1000)
     @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]"
-    layers = resnet(resnet_config[depth]...; nclasses,
-                    block_args = (; cardinality, base_width))
+    layers = resnet(resnet_config[depth]...; nclasses,
+                    block_args = (; cardinality, base_width))
     if pretrain
-        loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width))
+        loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width))
     end
-    return model
+    return ResNeXt(layers)
 end
+
+struct SEResNet
+    layers::Any
+end
+@functor SEResNet
+
+(m::SEResNet)(x) = m.layers(x)
+
+"""
+    SEResNet(depth::Integer; pretrain = false, nclasses = 1000)
+
+Creates a SEResNet model with the specified depth.
+([reference](https://arxiv.org/pdf/1709.01507.pdf))
+
+# Arguments
+
+  - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the SEResNet model.
+  - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+  - `nclasses`: the number of output classes
+"""
+function SEResNet(depth::Integer; pretrain = false, nclasses = 1000)
+    @assert depth in [18, 34, 50, 101, 152]
+    "Invalid depth. Must be one of [18, 34, 50, 101, 152]"
+    layers = resnet(resnet_config[depth]...; nclasses,
+                    block_args = (; attn_fn = squeeze_excite))
+    if pretrain
+        loadpretrain!(layers, string("seresnet", depth))
+    end
+    return SEResNet(layers)
+end
+
+struct SEResNeXt
+    layers::Any
+end
+@functor SEResNeXt
+
+(m::SEResNeXt)(x) = m.layers(x)
+
+"""
+    SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000)
+
+Creates a SEResNeXt model with the specified depth, cardinality, and base width.
+([reference](https://arxiv.org/pdf/1709.01507.pdf))
+
+# Arguments
+
+  - `depth`: one of `[50, 101, 152]`. The depth of the SEResNeXt model.
+  - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+  - `cardinality`: the number of groups to be used in the 3x3 convolution in each block.
+  - `base_width`: the number of feature maps in each group.
+  - `nclasses`: the number of output classes
+"""
+function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4,
+                   nclasses = 1000)
+    @assert depth in [50, 101, 152]
+    "Invalid depth. 
Must be one of [50, 101, 152]" + layers = resnet(resnet_config[depth]...; nclasses, + block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) if pretrain - loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) end - return model + return SEResNeXt(layers) end diff --git a/test/convnets.jl b/test/convnets.jl index 887abb410..b717574c1 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -3,11 +3,9 @@ @test size(model(x_256)) == (1000, 1) @test_throws ArgumentError AlexNet(pretrain = true) @test gradtest(model, x_256) + _gc() end -GC.safepoint() -GC.gc() - @testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] m = VGG(sz, batchnorm = bn) @@ -18,14 +16,10 @@ GC.gc() @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "ResNet" begin # Tests for pretrained ResNets ## TODO: find a way to port pretrained models to the new ResNet API @@ -55,17 +49,13 @@ GC.gc() m = Metalhead.resnet(block_fn, layers; drop_rates) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end end end -GC.safepoint() -GC.gc() - @testset "ResNeXt" begin @testset for depth in [50, 101, 152] @testset for cardinality in [32, 64] @@ -78,15 +68,43 @@ GC.gc() @test_throws ArgumentError ResNeXt(depth, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end end -GC.safepoint() -GC.gc() +@testset "SEResNet" begin + @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) + @test size(m(x_224)) == (1000, 1) + if string("seresnet", depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNet(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end +end + +@testset "SEResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if string("seresnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + end +end @testset "EfficientNet" begin @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4] #, :b5, :b6, :b7, :b8] @@ -101,14 +119,10 @@ GC.gc() @test_throws ArgumentError EfficientNet(name, pretrain = true) end @test gradtest(m, x) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(x_224)) == (1000, 1) @@ -118,11 +132,9 @@ GC.gc() @test_throws ArgumentError GoogLeNet(pretrain = true) end @test gradtest(m, x_224) + _gc() end -GC.safepoint() -GC.gc() - @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @testset "Inceptionv3" begin @@ -135,8 +147,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "Inceptionv4" begin m = Inceptionv4() @test size(m(x_299)) == (1000, 2) @@ -147,8 +158,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "InceptionResNetv2" begin m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) @@ -159,8 +169,7 @@ GC.gc() end @test 
gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "Xception" begin m = Xception() @test size(m(x_299)) == (1000, 2) @@ -171,11 +180,9 @@ GC.gc() end @test gradtest(m, x_299) end + _gc() end -GC.safepoint() -GC.gc() - @testset "SqueezeNet" begin m = SqueezeNet() @test size(m(x_224)) == (1000, 1) @@ -185,15 +192,12 @@ GC.gc() @test_throws ArgumentError SqueezeNet(pretrain = true) end @test gradtest(m, x_224) + _gc() end -GC.safepoint() -GC.gc() - @testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) if (DenseNet, sz) in PRETRAINED_MODELS @test acctest(DenseNet(sz, pretrain = true)) @@ -201,18 +205,13 @@ GC.gc() @test_throws ArgumentError DenseNet(sz, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "MobileNet" verbose = true begin @testset "MobileNetv1" begin m = MobileNetv1() - @test size(m(x_224)) == (1000, 1) if MobileNetv1 in PRETRAINED_MODELS @test acctest(MobileNetv1(pretrain = true)) @@ -221,8 +220,7 @@ GC.gc() end @test gradtest(m, x_224) end - GC.safepoint() - GC.gc() + _gc() @testset "MobileNetv2" begin m = MobileNetv2() @test size(m(x_224)) == (1000, 1) @@ -233,12 +231,10 @@ GC.gc() end @test gradtest(m, x_224) end - GC.safepoint() - GC.gc() + _gc() @testset "MobileNetv3" verbose = true begin @testset for mode in [:small, :large] m = MobileNetv3(mode) - @test size(m(x_224)) == (1000, 1) if (MobileNetv3, mode) in PRETRAINED_MODELS @test acctest(MobileNetv3(mode; pretrain = true)) @@ -246,12 +242,11 @@ GC.gc() @test_throws ArgumentError MobileNetv3(mode; pretrain = true) end @test gradtest(m, x_224) + _gc() end end end -GC.safepoint() -GC.gc() @testset "ConvNeXt" verbose = true begin @testset for mode in [:small, :base] #, :large # :tiny, #, :xlarge] @@ -259,22 +254,16 @@ GC.gc() m = ConvNeXt(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end -GC.safepoint() -GC.gc() - @testset "ConvMixer" verbose = true begin @testset for mode in [:small, :base] #, :large] m = ConvMixer(mode) - @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end diff --git a/test/other.jl b/test/other.jl index 3c1752f3a..df97d4f5f 100644 --- a/test/other.jl +++ b/test/other.jl @@ -4,8 +4,7 @@ m = MLPMixer(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end @@ -16,8 +15,7 @@ end m = ResMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end @@ -28,8 +26,7 @@ end m = gMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end diff --git a/test/runtests.jl b/test/runtests.jl index f1a9787b9..55e416ac2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,11 @@ const PRETRAINED_MODELS = [ (ResNet, 152), ] +function _gc() + GC.safepoint() + GC.gc() +end + function gradtest(model, input) y, pb = Zygote.pullback(() -> model(input), Flux.params(model)) gs = pb(ones(Float32, size(y))) diff --git a/test/vit-based.jl b/test/vit-based.jl index 9dc348819..e889b07be 100644 --- a/test/vit-based.jl +++ b/test/vit-based.jl @@ -3,7 +3,6 @@ m = ViT(mode) @test size(m(x_256)) == (1000, 1) @test gradtest(m, x_256) - GC.safepoint() - GC.gc() + _gc() end end From 3a4ffbfcd7491b3d76dd5db1a180cde7442e7ffa Mon Sep 17 00:00:00 
2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Mon, 4 Jul 2022 17:46:51 +0530 Subject: [PATCH 16/64] More docs, more tweaks --- src/Metalhead.jl | 4 +- src/convnets/convnext.jl | 22 ++-- src/convnets/resnets.jl | 270 ++++++++++++++++++++++++++++++++------- src/layers/Layers.jl | 6 +- src/layers/attention.jl | 4 +- src/layers/classifier.jl | 13 ++ src/layers/conv.jl | 43 ++----- src/layers/embeddings.jl | 2 +- src/layers/normalise.jl | 4 +- src/layers/pool.jl | 35 +++-- src/layers/selayers.jl | 1 - src/other/mlpmixer.jl | 6 +- src/vit-based/vit.jl | 2 +- 13 files changed, 297 insertions(+), 115 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index a4dd73785..e60eff405 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -49,9 +49,9 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt, +for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, - :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :SEResNet, :SEResNeXt, + :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index f3da6dbf3..6ced7eeb9 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -4,7 +4,7 @@ Creates a single block of ConvNeXt. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `planes`: number of input channels. - `drop_path_rate`: Stochastic depth rate. @@ -27,7 +27,7 @@ end Creates the layers for a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `inchannels`: number of input channels. - `depths`: list with configuration for depth of each block @@ -39,32 +39,29 @@ Creates the layers for a ConvNeXt model. """ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert length(depths)==length(planes) "`planes` should have exactly one value for each block" - + @assert length(depths) == length(planes) + "`planes` should have exactly one value for each block" downsample_layers = [] stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4), - ChannelLayerNorm(planes[1]; ϵ = 1.0f-6)) + ChannelLayerNorm(planes[1])) push!(downsample_layers, stem) for m in 1:(length(depths) - 1) - downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6), + downsample_layer = Chain(ChannelLayerNorm(planes[m]), Conv((2, 2), planes[m] => planes[m + 1]; stride = 2)) push!(downsample_layers, downsample_layer) end - stages = [] dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths)) cur = 0 - for i in 1:length(depths) + for i in eachindex(depths) push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]]) cur += depths[i] end - backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages)))) head = Chain(GlobalMeanPool(), MLUtils.flatten, LayerNorm(planes[end]), Dense(planes[end], nclasses)) - return Chain(Chain(backbone), head) end @@ -90,7 +87,7 @@ end Creates a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `inchannels`: The number of channels in the input. - `drop_path_rate`: Stochastic depth rate. 
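As a usage sketch for the constructor documented here (editorial addition; the keyword names are taken from the surrounding diff):

    using Metalhead
    m = ConvNeXt(:base; drop_path_rate = 0.1, nclasses = 10)
    size(m(rand(Float32, 224, 224, 3, 1)))  # expected to be (10, 1)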
@@ -101,7 +98,8 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))" + @assert mode in keys(convnext_configs) + "`size` must be one of $(collect(keys(convnext_configs)))" depths = convnext_configs[mode][:depths] planes = convnext_configs[mode][:planes] layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses) diff --git a/src/convnets/resnets.jl b/src/convnets/resnets.jl index 73f070617..f46ac29ad 100644 --- a/src/convnets/resnets.jl +++ b/src/convnets/resnets.jl @@ -1,31 +1,50 @@ -function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(outplanes)) -end +# resnet.jl +## It is recommended to check out the user's guide for more information +## regarding the use of these functions. + +### ResNet blocks +## These functions return a block to be used inside of a ResNet model. +## The individual arguments are explained in the documentation of the functions. +## Note that for these blocks to be used by the `_make_blocks` function, they must define +## a dispatch `expansion(::typeof(fn))` that returns the expansion factor of the block +## (i.e. the multiplicative factor by which the number of channels in the input is increased). +## The `_make_blocks` function will then call the `expansion` function to determine the +## expansion factor of each block and use this to construct the stages of the model. -function downsample_avg(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 - pool = MeanPool((2, 2); stride = avg_stride, pad) - end - return Chain(pool, - Conv((1, 1), inplanes => outplanes; bias = false), - norm_layer(outplanes)) -end +""" + basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) + +Creates a basic ResNet block. +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `downsample`: the downsampling function to use + - `cardinality`: redundant, kept for compatibility with `bottleneck`. + - `base_width`: redundant, kept for compatibility with `bottleneck`. + - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + convolution. + - `dilation`: the dilation of the second convolution. + - `first_dilation`: the dilation of the first convolution. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer. 
This is usually initialised in the `_make_blocks` + function and passed in. + - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. + - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the + attention function. +""" function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) @@ -33,7 +52,6 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina @assert base_width==64 "`basicblock` does not support changing base width" first_planes = planes ÷ reduce_first outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, dilation = first_dilation, bias = false), norm_layer(first_planes)) @@ -49,16 +67,46 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina end expansion_factor(::typeof(basicblock)) = 1 +""" + bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = dilation, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) + +Creates a bottleneck ResNet block. + +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `downsample`: the downsampling function to use + - `cardinality`: the number of groups in the convolution. + - `base_width`: the number of output feature maps for each convolutional group. + - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + convolution. + - `dilation`: redundant, kept for compatibility with `basicblock`. + - `first_dilation`: the dilation of the 3x3 convolution. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. + - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the + attention function. 
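These block-level keywords are normally reached through the `block_args` NamedTuple of `resnet` rather than by calling the block constructors directly. A minimal sketch of that path, mirroring what the `ResNeXt` constructor further down in this patch does (illustrative only; `resnet` and `bottleneck` are not exported, hence the qualified names):

    using Metalhead

    # A ResNeXt-50-style backbone: bottleneck blocks with 32 groups of width 4,
    # forwarded to every block through `block_args`.
    layers = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                              nclasses = 10,
                              block_args = (; cardinality = 32, base_width = 4))
    size(layers(rand(Float32, 224, 224, 3, 1)))    # expected to be (10, 1)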
+""" function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), norm_layer(first_planes, activation)) conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, @@ -77,17 +125,17 @@ expansion_factor(::typeof(bottleneck)) = 4 resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) -Builds a stem to be used in a ResNet model. See the `stem` argument of `resnet` for details +Builds a stem to be used in a ResNet model. See the `stem` argument of [`resnet`](#) for details on how to use this function. -# Arguments: +# Arguments - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. + `:default`: Builds a stem based on the default ResNet stem, which consists of a single 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling layer with stride 2. - + `:deep`: This borrows ideas from other papers (InceptionResNet-v2 for one) in using a + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using a deeper stem with 3 successive 3x3 convolutions having normalisation layers after each one. This is followed by a 3x3 max pooling layer with stride 2. + `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second @@ -137,13 +185,62 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = return Chain(conv1, bn1, stempool), inplanes end +### Downsampling layers +## These will almost never be used directly. They are used by the `_make_blocks` function to +## build the downsampling layers. In most cases, these defaults will not need to be changed. +## If you wish to write your own ResNet model using the `_make_blocks` function, you can use +## this function to build the downsampling layers. + +# Downsample layer using convolutions. +function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, + norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size + dilation = kernel_size[1] > 1 ? dilation : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, + dilation, bias = false), + norm_layer(outplanes)) +end + +# Downsample layer using max pooling +function downsample_pool(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, + norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 + pool = MeanPool((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), inplanes => outplanes; bias = false), + norm_layer(outplanes)) +end + +""" + downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), + stride = 1, dilation = 1, norm_layer = BatchNorm) + +Wrapper function that makes it easier to build a downsample block inside a ResNet model. +This function is almost never used directly or customised by the user. + +# Arguments + + - `downsample_fn`: The function to use for downsampling in skip connections. Recommended usage + is passing in either `downsample_conv` or `downsample_pool`. + - `inplanes`: The number of input feature maps. + - `planes`: The number of output feature maps. + - `expansion`: The expansion factor of the block. + - `kernel_size`: The size of the convolutional kernel. + - `stride`: The stride of the convolutional layer. + - `dilation`: The dilation of the convolutional layer. + - `norm_layer`: The normalisation layer to be used. +""" function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), - stride = 1, dilation = 1, first_dilation = dilation, - norm_layer = BatchNorm) + stride = 1, dilation = 1, norm_layer = BatchNorm) if stride != 1 || inplanes != planes * expansion downsample = downsample_fn(kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation, - norm_layer) + stride, dilation, norm_layer) else downsample = identity end @@ -166,6 +263,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # Stochastic depth linear decay rule (DropPath) dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), sum(block_repeats)) + # Construct each stage for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, _drop_blocks(get(drop_rates, @@ -181,10 +279,11 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride end # Downsample block; either a (default) convolution-based block or a pooling-based block downsample = downsample_block(downsample_fn, inplanes, planes, expansion; - stride, dilation, first_dilation = dilation) + stride, dilation) # Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks + # Different behaviour for the first block of each stage downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 push!(blocks, @@ -209,7 +308,62 @@ function _drop_blocks(drop_block_prob = 0.0) ] end -function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32, +""" + resnet(block_type, layers; inchannels = 3, nclasses = 1000, output_stride = 32, + stem = first(resnet_stem(; inchannels)), inplanes = 64, + downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), + drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = 0.0), + classifier_args::NamedTuple = NamedTuple()) + +This function creates the layers for many ResNet-like models. + +!!! note + + If you are an end-user trying to use ResNet-like models, you should consider [`ResNet`](#) + and similar higher-level functions instead. This version is significantly more customisable + at the cost of being more complicated. + +# Arguments + + - `block_fn`: The type of block to use inside the ResNet model. Must be either `:basicblock`, + which is the standard ResNet block, or `:bottleneck`, which is the ResNet block with a + bottleneck structure. 
See the [paper](https://arxiv.org/abs/1512.03385) for more details. + + - `layers`: A list of integers specifying the number of blocks in each stage. For example, + `[3, 4, 6, 3]` would mean that the network would have 4 stages, with 3, 4, 6 and 3 blocks in + each. + - `nclasses`: The number of output classes. + - `inchannels`: The number of input channels. + - `output_stride`: The total stride of the network i.e. the amount by which the input is + downsampled throughout the network. This is used to determine the output size from the + backbone of the network. Must be one of `[8, 16, 32]`. + - `stem`: A constructed ResNet stem, passed in to be used in the model. `inplanes` should be + set to the number of output channels from this stem. Metalhead provides an in-built + function for creating a stem (see [`resnet_stem`](#)) but you can also create your + own (although this is not usually necessary). + - `inplanes`: The number of output channels from the stem. + - `downsample_type`: The type of downsampling to use. Either `:conv` or `:pool`. The former + uses a traditional convolution-based downsampling, while the latter is an + average-pooling-based downsampling that was suggested in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) + paper. + - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed + to the block function. For more information regarding valid arguments, see + the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). + - `drop_rates`: A `NamedTuple` that can may define none, some or all of the following: + + + `dropout_rate`: The rate of dropout to be used in the classifier head. + + `drop_path_rate`: Stochastic depth implemented using [`DropPath`](#). + + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). + - `classifier_args`: A `NamedTuple` that may define none, some or all of the following: + + + `pool_type`: The type of pooling to use in the classifier head. Uses + [`SelectAdaptivePool`](#) to select the pooling function. See its + documentation for more information. + + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a + `Dense` layer. +""" +function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, @@ -241,7 +395,7 @@ end (m::ResNet)(x) = m.layers(x) """ - ResNet(depth::Integer; pretrain = false, nclasses = 1000) + ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) Creates a ResNet model with the specified depth. ((reference)[https://arxiv.org/abs/1512.03385]) @@ -250,6 +404,7 @@ Creates a ResNet model with the specified depth. - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. - `nclasses`: the number of output classes !!! warning @@ -258,10 +413,10 @@ Creates a ResNet model with the specified depth. Advanced users who want more configuration options will be better served by using [`resnet`](#). """ -function ResNet(depth::Integer; pretrain = false, nclasses = 1000) +function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. 
Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses) + layers = resnet(resnet_config[depth]...; inchannels, nclasses) if pretrain loadpretrain!(layers, string("resnet", depth)) end @@ -276,7 +431,8 @@ end (m::ResNeXt)(x) = m.layers(x) """ - ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) Creates a ResNeXt model with the specified depth, cardinality, and base width. ((reference)[https://arxiv.org/abs/1611.05431]) @@ -287,6 +443,7 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels. - `nclasses`: the number of output classes !!! warning @@ -295,11 +452,11 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. Advanced users who want more configuration options will be better served by using [`resnet`](#). """ -function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - nclasses = 1000) +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses, + layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; cardinality, base_width)) if pretrain loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width)) @@ -315,7 +472,7 @@ end (m::SEResNet)(x) = m.layers(x) """ - SEResNet(depth::Integer; pretrain = false, nclasses = 1000) + SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) Creates a SEResNet model with the specified depth. ((reference)[https://arxiv.org/pdf/1709.01507.pdf]) @@ -324,12 +481,19 @@ Creates a SEResNet model with the specified depth. - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: the number of input channels. - `nclasses`: the number of output classes + +!!! warning + + `SEResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). """ -function SEResNet(depth::Integer; pretrain = false, nclasses = 1000) +function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses, + layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; attn_fn = squeeze_excite)) if pretrain loadpretrain!(layers, string("seresnet", depth)) @@ -345,7 +509,8 @@ end (m::SEResNeXt)(x) = m.layers(x) """ - SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) Creates a SEResNeXt model with the specified depth, cardinality, and base width. 
((reference)[https://arxiv.org/pdf/1709.01507.pdf]) @@ -356,13 +521,20 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels - `nclasses`: the number of output classes + +!!! warning + + `SEResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). """ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - nclasses = 1000) + inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses, + layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) if pretrain loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 41a98843e..a86361143 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -25,8 +25,7 @@ include("normalise.jl") export prenorm, ChannelLayerNorm include("conv.jl") -export conv_bn, depthwise_sep_conv_bn, invertedresidual -skip_identity, skip_projection +export conv_bn, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection include("drop.jl") export DropPath, DropBlock @@ -38,7 +37,6 @@ include("classifier.jl") export create_classifier include("pool.jl") -export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool -SelectAdaptivePool +export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool, SelectAdaptivePool end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 3cefe7c0d..7d8ee776d 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -3,7 +3,7 @@ Multi-head self-attention layer. -# Arguments: +# Arguments - `nheads`: Number of heads - `qkv_layer`: layer to be used for getting the query, key and value @@ -22,7 +22,7 @@ end Multi-head self-attention layer. -# Arguments: +# Arguments - `planes`: number of input channels - `nheads`: number of heads diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl index e2a1fe75c..0e9ba02d1 100644 --- a/src/layers/classifier.jl +++ b/src/layers/classifier.jl @@ -1,3 +1,16 @@ +""" + create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) + +Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `pool_type`: the type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, + `:catmeanmax` or `:identity`. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. 
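A minimal sketch of how this helper is meant to be used (illustrative only; at this point in the series `create_classifier` lives in the internal `Layers` module, so the qualified name below is an assumption about how it is reached from outside):

    using Flux, Metalhead

    # Mean-pooled classification head for a 512-channel feature map.
    pool, fc = Metalhead.Layers.create_classifier(512, 1000)   # pool_type = :mean, use_conv = false
    head = Chain(pool, fc)
    size(head(rand(Float32, 7, 7, 512, 1)))    # expected to be (1000, 1)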
+""" function create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling if pool_type == :identity diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e56967aef..7605a6cd1 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,8 +1,7 @@ """ conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, - groups = 1, [bias, weight, init], initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1) + rev = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, + groups = 1, [bias, weight, init]) Create a convolution + batch normalization pair with activation. @@ -22,13 +21,9 @@ Create a convolution + batch normalization pair with activation. - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) - - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) - - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) """ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, - initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1, - kwargs...) + rev = false, preact = false, use_bn = true, kwargs...) if !use_bn (preact || rev) ? throw("preact only supported with `use_bn = true`") : return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)] @@ -48,17 +43,14 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...)) push!(layers, - BatchNorm(Int(bnplanes), activations.bn; - initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum)) + BatchNorm(Int(bnplanes), activations.bn)) return rev ? reverse(layers) : layers end """ depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; rev = false, use_bn = (true, true), - stride = 1, pad = 0, dilation = 1, [bias, weight, init], - initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1) + stride = 1, pad = 0, dilation = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: @@ -82,21 +74,13 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) - - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) - - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) """ function depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; rev = false, use_bn = (true, true), - initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1, stride = 1, kwargs...) 
return vcat(conv_bn(kernelsize, inplanes, inplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, - ϵ = ϵ, momentum = momentum, use_bn = use_bn[1], - stride = stride, groups = Int(inplanes), kwargs...), - conv_bn((1, 1), inplanes, outplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, use_bn = use_bn[2], - ϵ = ϵ, momentum = momentum)) + rev, use_bn = use_bn[1], stride, groups = Int(inplanes), kwargs...), + conv_bn((1, 1), inplanes, outplanes, activation; rev, use_bn = use_bn[2])) end """ @@ -105,10 +89,10 @@ end Create a skip projection ([reference](https://arxiv.org/abs/1512.03385v1)). -# Arguments: +# Arguments - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps - `downsample`: set to `true` to downsample the input """ function skip_projection(inplanes, outplanes, downsample = false) @@ -124,7 +108,7 @@ end Create a identity projection ([reference](https://arxiv.org/abs/1512.03385v1)). -# Arguments: +# Arguments - `inplanes`: the number of input feature maps - `outplanes`: the number of output feature maps @@ -153,15 +137,14 @@ Create a basic inverted residual block for MobileNet variants # Arguments - - `kernel_size`: The kernel size of the convolutional layers - - `inplanes`: The number of input feature maps + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps - `hidden_planes`: The number of feature maps in the hidden layer - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)). - Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; stride, reduction = nothing) diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index 66f25d1c0..3e85f18d9 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -8,7 +8,7 @@ _flatten_spatial(x) = permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1 Patch embedding layer used by many vision transformer-like models to split the input image into patches. -# Arguments: +# Arguments - `imsize`: the size of the input image - `inchannels`: the number of channels in the input. diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c767bd1e0..e71634e22 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -2,7 +2,7 @@ prenorm(planes, fn) = Chain(LayerNorm(planes), fn) """ - ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1f-5) + ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) A variant of LayerNorm where the input is normalised along the channel dimension. 
The input is expected to have channel dimension with size @@ -19,7 +19,7 @@ end @functor ChannelLayerNorm -function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) +function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) diag = Flux.Scale(1, 1, sz, λ) return ChannelLayerNorm(diag, ϵ) end diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 4ffe298e3..36a08a8da 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,21 +1,40 @@ -function AdaptiveMeanMaxPool(output_size = (1, 1)) - return 0.5 * Parallel(.+, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) -end +""" + AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) + +A type of adaptive pooling layer which uses both mean and max pooling and combines them to +produce a single output. Note that this is equivalent to +`Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))` + +# Arguments -function AdaptiveCatMeanMaxPool(output_size = (1, 1)) - return Parallel(cat_channels, AdaptiveAvgMaxPool(output_size), - AdaptiveMaxPool(output_size)) + - `output_size`: The size of the output after pooling. + - `connection`: The connection type to use. +""" +function AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) + return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) end +""" + SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) + +Adaptive pooling factory function. + +# Arguments + + - `output_size`: The size of the output after pooling. + - `pool_type`: The type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, + `:catmeanmax` or `:identity`. + - `flatten`: Whether to flatten the output from the pooling layer. +""" function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) if pool_type == :mean pool = AdaptiveMeanPool(output_size) elseif pool_type == :max pool = AdaptiveMaxPool(output_size) elseif pool_type == :meanmax - pool = AdaptiveMeanMaxPool(output_size) + pool = 0.5f0 * AdaptiveMeanMaxPool(output_size) elseif pool_type == :catmeanmax - pool = AdaptiveCatMeanMaxPool(output_size) + pool = AdaptiveMeanMaxPool(output_size; connection = cat_channels) elseif pool_type == :identity pool = identity else diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 7f1a76d59..6d86947c9 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -14,7 +14,6 @@ Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer - Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, activation = relu, gate_activation = sigmoid, diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 5083b228e..aab7cba4e 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -5,7 +5,7 @@ Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) -# Arguments: +# Arguments - `planes`: the number of planes in the block - `npatches`: the number of patches of the input @@ -55,8 +55,8 @@ Creates a model with the MLPMixer architecture. not specified. 
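For orientation, a minimal sketch of driving this builder directly with the `mixerblock` shown earlier in this file (illustrative only; the parameter values are arbitrary small choices, not presets from the library):

    using Metalhead

    # A small MLP-Mixer-style model: 4 mixer blocks over 16×16 patches of a 224×224 image.
    m = Metalhead.mlpmixer(Metalhead.mixerblock, (224, 224);
                           patch_size = (16, 16), embedplanes = 256,
                           depth = 4, drop_path_rate = 0.1, nclasses = 10)
    size(m(rand(Float32, 224, 224, 3, 1)))    # expected to be (10, 1)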
""" function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, - norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0.0, + norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), + embedplanes = 512, drop_path_rate = 0.0, depth = 12, nclasses = 1000, kwargs...) npatches = prod(imsize .÷ patch_size) dp_rates = LinRange{Float32}(0.0, drop_path_rate, depth) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index a06ce6886..b3a7e167c 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -50,7 +50,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, emb_drop_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] - "Pool type must be either :class (class token) or :mean (mean pooling)" + "Pool type must be either `:class` (class token) or `:mean` (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), From 2f755cf7b28c1df60712ec2744db82e1a95f42bf Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 8 Jul 2022 21:54:48 +0530 Subject: [PATCH 17/64] More aggressive GC Co-authored-by: Brian Chen --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 55e416ac2..1a8c77f25 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,7 @@ const PRETRAINED_MODELS = [ function _gc() GC.safepoint() - GC.gc() + GC.gc(true) end function gradtest(model, input) From 5ba4b84f8f12188fc8243510c68ce0c6295631c8 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 9 Jul 2022 09:35:31 +0530 Subject: [PATCH 18/64] Tweaks don't stop Neither does formatting, unfortunately. Also refactor `classifier` to separate out FC-layer creation and pooling --- src/convnets/inception.jl | 8 +-- src/convnets/resnets.jl | 78 ++++++++++++++-------------- src/layers/Layers.jl | 10 ++-- src/layers/classifier.jl | 25 --------- src/layers/{mlp-linear.jl => mlp.jl} | 33 ++++++------ src/layers/pool.jl | 34 +----------- src/layers/scale.jl | 24 +++++++++ src/other/mlpmixer.jl | 4 +- src/utilities.jl | 10 ---- src/vit-based/vit.jl | 2 +- 10 files changed, 95 insertions(+), 133 deletions(-) delete mode 100644 src/layers/classifier.jl rename src/layers/{mlp-linear.jl => mlp.jl} (79%) create mode 100644 src/layers/scale.jl diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index 156362cf3..5823e9737 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -425,7 +425,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) + inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -458,7 +458,7 @@ function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000 end """ - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an InceptionResNetv2 model. 
([reference](https://arxiv.org/abs/1602.07261)) @@ -542,7 +542,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end """ - xception(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) + xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -573,7 +573,7 @@ struct Xception end """ - Xception(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) + Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) diff --git a/src/convnets/resnets.jl b/src/convnets/resnets.jl index f46ac29ad..a70fc7f07 100644 --- a/src/convnets/resnets.jl +++ b/src/convnets/resnets.jl @@ -1,6 +1,5 @@ # resnet.jl -## It is recommended to check out the user's guide for more information -## regarding the use of these functions. +## It is recommended to check out the user guide for more information. ### ResNet blocks ## These functions return a block to be used inside of a ResNet model. @@ -12,10 +11,9 @@ ## expansion factor of each block and use this to construct the stages of the model. """ - basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, + basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, + dilation = 1, first_dilation = dilation, activation = relu, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) Creates a basic ResNet block. @@ -26,8 +24,6 @@ Creates a basic ResNet block. - `planes`: number of feature maps for the block - `stride`: the stride of the block - `downsample`: the downsampling function to use - - `cardinality`: redundant, kept for compatibility with `bottleneck`. - - `base_width`: redundant, kept for compatibility with `bottleneck`. - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first convolution. - `dilation`: the dilation of the second convolution. @@ -42,14 +38,11 @@ Creates a basic ResNet block. - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the attention function. 
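Putting the pieces together, a sketch of building a ResNet-18-style network from `basicblock` with the deeper stem provided by `resnet_stem` (illustrative only; it relies on the `stem`/`inplanes` keywords of `resnet` documented earlier in this series):

    using Metalhead

    # Deep 3×3 stem instead of the default 7×7 stem; `resnet_stem` also returns the
    # number of channels the stem produces, which `resnet` needs as `inplanes`.
    stem, inplanes = Metalhead.resnet_stem(; stem_type = :deep)
    layers = Metalhead.resnet(Metalhead.basicblock, [2, 2, 2, 2];
                              stem, inplanes, nclasses = 10)
    size(layers(rand(Float32, 224, 224, 3, 1)))    # expected to be (10, 1)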
""" -function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = dilation, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, +function basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, + dilation = 1, first_dilation = dilation, activation = relu, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) - @assert cardinality==1 "`basicblock` only supports cardinality of 1" - @assert base_width==64 "`basicblock` does not support changing base width" first_planes = planes ÷ reduce_first outplanes = planes * expansion conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, @@ -69,8 +62,8 @@ expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = dilation, activation = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, first_dilation = 1, + activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -86,7 +79,6 @@ Creates a bottleneck ResNet block. - `base_width`: the number of output feature maps for each convolutional group. - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `dilation`: redundant, kept for compatibility with `basicblock`. - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. @@ -99,8 +91,8 @@ Creates a bottleneck ResNet block. attention function. """ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = dilation, activation = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, first_dilation = 1, + activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) @@ -263,12 +255,12 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # Stochastic depth linear decay rule (DropPath) dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), sum(block_repeats)) - # Construct each stage - for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, - block_repeats, - _drop_blocks(get(drop_rates, - :drop_block_rate, - 0)))) + # DropBlock rate + dbr = get(drop_rates, :drop_block_rate, 0) + ## Construct each stage + for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr))) + # Number of planes in each stage, number of blocks in each stage, and the drop block rate + planes, num_blocks, drop_block = itr # Stride calculations for each stage stride = stage_idx == 1 ? 
1 : 2 if net_stride >= output_stride @@ -280,7 +272,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # Downsample block; either a (default) convolution-based block or a pooling-based block downsample = downsample_block(downsample_fn, inplanes, planes, expansion; stride, dilation) - # Construct the blocks for each stage + ## Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks # Different behaviour for the first block of each stage @@ -309,14 +301,16 @@ function _drop_blocks(drop_block_prob = 0.0) end """ - resnet(block_type, layers; inchannels = 3, nclasses = 1000, output_stride = 32, + resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), - classifier_args::NamedTuple = NamedTuple()) + classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), + use_conv = false)) -This function creates the layers for many ResNet-like models. +This function creates the layers for many ResNet-like models. See the user guide for more +information. !!! note @@ -350,16 +344,14 @@ This function creates the layers for many ResNet-like models. - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed to the block function. For more information regarding valid arguments, see the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). - - `drop_rates`: A `NamedTuple` that can may define none, some or all of the following: + - `drop_rates`: A `NamedTuple` that may define none, some or all of the following: + `dropout_rate`: The rate of dropout to be used in the classifier head. + `drop_path_rate`: Stochastic depth implemented using [`DropPath`](#). + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). - - `classifier_args`: A `NamedTuple` that may define none, some or all of the following: + - `classifier_args`: A `NamedTuple` that **must** specify the following arguments: - + `pool_type`: The type of pooling to use in the classifier head. Uses - [`SelectAdaptivePool`](#) to select the pooling function. See its - documentation for more information. + + `pool_layer`: The adaptive pooling layer to use in the classifier head. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. """ @@ -368,15 +360,25 @@ function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), - classifier_args::NamedTuple = NamedTuple()) - # Feature Blocks + classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), + use_conv = false)) + ## Feature Blocks channels = [64, 128, 256, 512] stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fn, drop_rates, block_args) - # Head (Pooling and Classifier) + ## Classifier head expansion = expansion_factor(block_fn) num_features = 512 * expansion - global_pool, fc = create_classifier(num_features, nclasses; classifier_args...) 
+ pool_layer, use_conv = classifier_args + # Pooling + if pool_layer === identity + @assert use_conv + "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" + end + flatten_in_pool = !use_conv && pool_layer !== identity + global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer + # Fully-connected layer + fc = create_fc(num_features, nclasses; use_conv) classifier = Chain(global_pool, Dropout(get(drop_rates, :dropout_rate, 0)), fc) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index a86361143..2c4b11e5a 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -18,8 +18,8 @@ export MHAttention include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens -include("mlp-linear.jl") -export mlp_block, gated_mlp_block, LayerScale +include("mlp.jl") +export mlp_block, gated_mlp_block, create_fc include("normalise.jl") export prenorm, ChannelLayerNorm @@ -33,10 +33,10 @@ export DropPath, DropBlock include("selayers.jl") export squeeze_excite, effective_squeeze_excite -include("classifier.jl") -export create_classifier +include("scale.jl") +export LayerScale, inputscale include("pool.jl") -export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool, SelectAdaptivePool +export AdaptiveMeanMaxPool end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl deleted file mode 100644 index 0e9ba02d1..000000000 --- a/src/layers/classifier.jl +++ /dev/null @@ -1,25 +0,0 @@ -""" - create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) - -Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. - -# Arguments - - - `inplanes`: number of input feature maps - - `nclasses`: number of output classes - - `pool_type`: the type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, - `:catmeanmax` or `:identity`. - - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. -""" -function create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) - flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling - if pool_type == :identity - @assert use_conv - "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" - flatten_in_pool = false # disable flattening if pooling is pass-through (no pooling) - end - global_pool = SelectAdaptivePool(; pool_type, flatten = flatten_in_pool) - fc = use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : - Dense(inplanes => nclasses; bias = true) - return global_pool, fc -end diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp.jl similarity index 79% rename from src/layers/mlp-linear.jl rename to src/layers/mlp.jl index 8cca1e266..f72520451 100644 --- a/src/layers/mlp-linear.jl +++ b/src/layers/mlp.jl @@ -1,21 +1,6 @@ -""" - LayerScale(λ, planes::Integer) - -Creates a `Flux.Scale` layer that performs "`LayerScale`" -([reference](https://arxiv.org/abs/2103.17239)). - -# Arguments - - - `planes`: Size of channel dimension in the input. - - `λ`: initialisation value for the learnable diagonal matrix. -""" -function LayerScale(planes::Integer, λ) - return λ > 0 ? 
Flux.Scale(fill(Float32(λ), planes), false) : identity -end - """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout_rate =0., activation = gelu) + dropout_rate = 0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -60,3 +45,19 @@ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) + +""" + create_fc(inplanes, nclasses; use_conv = false) + +Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. +""" +function create_fc(inplanes, nclasses; use_conv = false) + return use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : + Dense(inplanes => nclasses; bias = true) +end diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 36a08a8da..0a74a24c0 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,5 +1,5 @@ """ - AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) + AdaptiveMeanMaxPool(output_size = (1, 1); connection = +) A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to @@ -10,36 +10,6 @@ produce a single output. Note that this is equivalent to - `output_size`: The size of the output after pooling. - `connection`: The connection type to use. """ -function AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) +function AdaptiveMeanMaxPool(output_size = (1, 1); connection = +) return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) end - -""" - SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) - -Adaptive pooling factory function. - -# Arguments - - - `output_size`: The size of the output after pooling. - - `pool_type`: The type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, - `:catmeanmax` or `:identity`. - - `flatten`: Whether to flatten the output from the pooling layer. -""" -function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) - if pool_type == :mean - pool = AdaptiveMeanPool(output_size) - elseif pool_type == :max - pool = AdaptiveMaxPool(output_size) - elseif pool_type == :meanmax - pool = 0.5f0 * AdaptiveMeanMaxPool(output_size) - elseif pool_type == :catmeanmax - pool = AdaptiveMeanMaxPool(output_size; connection = cat_channels) - elseif pool_type == :identity - pool = identity - else - throw(AssertionError("Invalid pool type: $pool_type")) - end - flatten_fn = flatten ? MLUtils.flatten : identity - return Chain(pool, flatten_fn) -end diff --git a/src/layers/scale.jl b/src/layers/scale.jl new file mode 100644 index 000000000..cd55fc97c --- /dev/null +++ b/src/layers/scale.jl @@ -0,0 +1,24 @@ +""" + inputscale(λ; activation = identity) + +Scale the input by a scalar `λ` and applies an activation function to it. +Equivalent to `activation.(λ .* x)`. +""" +inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) +_input_scale(x, λ, activation) = activation.(λ .* x) +_input_scale(x, λ, ::typeof(identity)) = λ .* x + +""" + LayerScale(λ, planes::Integer) + +Creates a `Flux.Scale` layer that performs "`LayerScale`" +([reference](https://arxiv.org/abs/2103.17239)). 
+ +# Arguments + + - `planes`: Size of channel dimension in the input. + - `λ`: initialisation value for the learnable diagonal matrix. +""" +function LayerScale(planes::Integer, λ) + return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity +end diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index aab7cba4e..48f1efd8c 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout_rate =0., drop_path_rate = 0., activation = gelu) + dropout_rate = 0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -115,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] """ - resmixerblock(planes, npatches; dropout_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, activation = gelu, λ = 1e-4) Creates a block for the ResMixer architecture. diff --git a/src/utilities.jl b/src/utilities.jl index 930cc621a..9c29350bd 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -18,16 +18,6 @@ Convenient reduction operator for use with `Parallel`. """ cat_channels(xy...) = cat(xy...; dims = Val(3)) -""" - inputscale(λ; activation = identity) - -Scale the input by a scalar `λ` and applies an activation function to it. -Equivalent to `activation.(λ .* x)`. -""" -inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) -_input_scale(x, λ, activation) = activation.(λ .* x) -_input_scale(x, λ, ::typeof(identity)) = λ .* x - """ swapdims(perm) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index b3a7e167c..856b64697 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate =0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). 
From aaf2abb41c4669b23d135fa55d2159af056207e9 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 9 Jul 2022 18:06:28 +0530 Subject: [PATCH 19/64] Reorganisation and formatting It really does never stop Co-Authored-By: Kyle Daruwalla --- src/Metalhead.jl | 6 +- src/convnets/{resnets.jl => resnets/core.jl} | 172 +------------------ src/convnets/resnets/resnet.jl | 35 ++++ src/convnets/resnets/resnext.jl | 40 +++++ src/convnets/resnets/seresnet.jl | 77 +++++++++ test/convnets.jl | 6 +- 6 files changed, 169 insertions(+), 167 deletions(-) rename src/convnets/{resnets.jl => resnets/core.jl} (77%) create mode 100644 src/convnets/resnets/resnet.jl create mode 100644 src/convnets/resnets/resnext.jl create mode 100644 src/convnets/resnets/seresnet.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index e60eff405..e88279270 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -22,13 +22,17 @@ include("convnets/alexnet.jl") include("convnets/vgg.jl") include("convnets/inception.jl") include("convnets/googlenet.jl") -include("convnets/resnets.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") +## ResNets +include("convnets/resnets/core.jl") +include("convnets/resnets/resnet.jl") +include("convnets/resnets/resnext.jl") +include("convnets/resnets/seresnet.jl") # Other models include("other/mlpmixer.jl") diff --git a/src/convnets/resnets.jl b/src/convnets/resnets/core.jl similarity index 77% rename from src/convnets/resnets.jl rename to src/convnets/resnets/core.jl index a70fc7f07..c53575c02 100644 --- a/src/convnets/resnets.jl +++ b/src/convnets/resnets/core.jl @@ -1,4 +1,3 @@ -# resnet.jl ## It is recommended to check out the user guide for more information. ### ResNet blocks @@ -11,7 +10,7 @@ ## expansion factor of each block and use this to construct the stages of the model. """ - basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, + basicblock(inplanes, planes; stride = 1, downsample = identity, reduction_factor = 1, dilation = 1, first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -24,7 +23,7 @@ Creates a basic ResNet block. - `planes`: number of feature maps for the block - `stride`: the stride of the block - `downsample`: the downsampling function to use - - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - `dilation`: the dilation of the second convolution. - `first_dilation`: the dilation of the first convolution. @@ -38,12 +37,13 @@ Creates a basic ResNet block. - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the attention function. 
""" -function basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, +function basicblock(inplanes, planes; stride = 1, downsample = identity, + reduction_factor = 1, dilation = 1, first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) - first_planes = planes ÷ reduce_first + first_planes = planes ÷ reduction_factor outplanes = planes * expansion conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, dilation = first_dilation, bias = false), @@ -62,7 +62,7 @@ expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, first_dilation = 1, + base_width = 64, reduction_factor = 1, first_dilation = 1, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -77,7 +77,7 @@ Creates a bottleneck ResNet block. - `downsample`: the downsampling function to use - `cardinality`: the number of groups in the convolution. - `base_width`: the number of output feature maps for each convolutional group. - - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. @@ -91,13 +91,13 @@ Creates a bottleneck ResNet block. attention function. """ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, first_dilation = 1, + base_width = 64, reduction_factor = 1, first_dilation = 1, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality - first_planes = width ÷ reduce_first + first_planes = width ÷ reduction_factor outplanes = planes * expansion conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), norm_layer(first_planes, activation)) @@ -389,157 +389,3 @@ const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) -struct ResNet - layers::Any -end -@functor ResNet - -(m::ResNet)(x) = m.layers(x) - -""" - ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - -Creates a ResNet model with the specified depth. -((reference)[https://arxiv.org/abs/1512.03385]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `inchannels`: The number of input channels. - - `nclasses`: the number of output classes - -!!! warning - - `ResNet` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 101, 152] - "Invalid depth. 
Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("resnet", depth)) - end - return ResNet(layers) -end - -struct ResNeXt - layers::Any -end -@functor ResNeXt - -(m::ResNeXt)(x) = m.layers(x) - -""" - ResNeXt(depth::Integer; pretrain = false, cardinality = 32, - base_width = 4, inchannels = 3, nclasses = 1000) - -Creates a ResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/abs/1611.05431]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - - `base_width`: the number of feature maps in each group. - - `inchannels`: the number of input channels. - - `nclasses`: the number of output classes - -!!! warning - - `ResNeXt` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, - base_width = 4, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width)) - if pretrain - loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width)) - end - return ResNeXt(layers) -end - -struct SEResNet - layers::Any -end -@functor SEResNet - -(m::SEResNet)(x) = m.layers(x) - -""" - SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - -Creates a SEResNet model with the specified depth. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `inchannels`: the number of input channels. - - `nclasses`: the number of output classes - -!!! warning - - `SEResNet` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 101, 152] - "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; attn_fn = squeeze_excite)) - if pretrain - loadpretrain!(layers, string("seresnet", depth)) - end - return SEResNet(layers) -end - -struct SEResNeXt - layers::Any -end -@functor SEResNeXt - -(m::SEResNeXt)(x) = m.layers(x) - -""" - SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - inchannels = 3, nclasses = 1000) - -Creates a SEResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - - `base_width`: the number of feature maps in each group. - - `inchannels`: the number of input channels - - `nclasses`: the number of output classes - -!!! 
warning - - `SEResNeXt` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) - if pretrain - loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) - end - return SEResNeXt(layers) -end diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl new file mode 100644 index 000000000..cd26e69a3 --- /dev/null +++ b/src/convnets/resnets/resnet.jl @@ -0,0 +1,35 @@ +""" + ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a ResNet model with the specified depth. +((reference)[https://arxiv.org/abs/1512.03385]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `ResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct ResNet + layers::Any +end +@functor ResNet + +(m::ResNet)(x) = m.layers(x) + +function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("ResNet", depth)) + end + return ResNet(layers) +end diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl new file mode 100644 index 000000000..1fa00a7b0 --- /dev/null +++ b/src/convnets/resnets/resnext.jl @@ -0,0 +1,40 @@ +""" + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) + +Creates a ResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/abs/1611.05431]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `ResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct ResNeXt + layers::Any +end +@functor ResNeXt + +(m::ResNeXt)(x) = m.layers(x) + +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + block_args = (; cardinality, base_width)) + if pretrain + loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) + end + return ResNeXt(layers) +end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl new file mode 100644 index 000000000..58c6a0607 --- /dev/null +++ b/src/convnets/resnets/seresnet.jl @@ -0,0 +1,77 @@ +""" + SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a SEResNet model with the specified depth. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: the number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `SEResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct SEResNet + layers::Any +end +@functor SEResNet + +(m::SEResNet)(x) = m.layers(x) + +function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + block_args = (; attn_fn = squeeze_excite)) + if pretrain + loadpretrain!(layers, string("SEResNet", depth)) + end + return SEResNet(layers) +end + +""" + SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) + +Creates a SEResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels + - `nclasses`: the number of output classes + +!!! warning + + `SEResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct SEResNeXt + layers::Any +end +@functor SEResNeXt + +(m::SEResNeXt)(x) = m.layers(x) + +function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) + if pretrain + loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) + end + return SEResNeXt(layers) +end diff --git a/test/convnets.jl b/test/convnets.jl index b717574c1..f0f86a2c1 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -62,7 +62,7 @@ end @testset for base_width in [4, 8] m = ResNeXt(depth; cardinality, base_width) @test size(m(x_224)) == (1000, 1) - if string("resnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS @test acctest(ResNeXt(depth, pretrain = true)) else @test_throws ArgumentError ResNeXt(depth, pretrain = true) @@ -78,7 +78,7 @@ end @testset for depth in [18, 34, 50, 101, 152] m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) - if string("seresnet", depth) in PRETRAINED_MODELS + if (SEResNet, depth) in PRETRAINED_MODELS @test acctest(SEResNet(depth, pretrain = true)) else @test_throws ArgumentError SEResNet(depth, pretrain = true) @@ -94,7 +94,7 @@ end @testset for base_width in [4, 8] m = SEResNeXt(depth; cardinality, base_width) @test size(m(x_224)) == (1000, 1) - if string("seresnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS @test acctest(SEResNeXt(depth, pretrain = true)) else @test_throws ArgumentError SEResNeXt(depth, pretrain = true) From 326f36cb33d518a643dde069358830b1f968d92b Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 9 Jul 2022 23:09:06 +0530 Subject: [PATCH 20/64] Refactor shortcut connections --- Project.toml | 1 + src/Metalhead.jl | 1 + src/convnets/inception.jl | 2 +- src/convnets/resnets/core.jl | 114 ++++++++++++++++++++++++----------- src/layers/scale.jl | 6 +- src/utilities.jl | 22 +++++++ 6 files changed, 108 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index 546717e39..0ad3086d0 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/Metalhead.jl b/src/Metalhead.jl index e88279270..67731825e 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,6 +7,7 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils +using PartialFunctions using Random import Functors diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index 5823e9737..ba9c935f6 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -530,7 +530,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, inc = inchannels outc = i == nrepeats ? 
outchannels : inchannels end - push!(layers, x -> relu.(x)) + push!(layers, relu) append!(layers, depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, use_bn = (false, false))) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index c53575c02..295a7698f 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -10,8 +10,9 @@ ## expansion factor of each block and use this to construct the stages of the model. """ - basicblock(inplanes, planes; stride = 1, downsample = identity, reduction_factor = 1, - dilation = 1, first_dilation = dilation, activation = relu, + basicblock(inplanes, planes; stride = 1, downsample = identity, + reduction_factor = 1, dilation = 1, first_dilation = dilation, + activation = relu, connection = addact\$activation, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -28,6 +29,9 @@ Creates a basic ResNet block. - `dilation`: the dilation of the second convolution. - `first_dilation`: the dilation of the first convolution. - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -38,8 +42,8 @@ Creates a basic ResNet block. attention function. """ function basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, - dilation = 1, first_dilation = dilation, activation = relu, + reduction_factor = 1, dilation = 1, first_dilation = dilation, + activation = relu, connection = addact$activation, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) @@ -53,18 +57,17 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, dilation = dilation, bias = false), norm_layer(outplanes)) attn_layer = attn_fn(outplanes; attn_args...) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, - drop_path)), - activation) + return Parallel(connection, downsample, + Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, + drop_path)) end expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, + activation = relu, connection = addact\$activation, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) Creates a bottleneck ResNet block. @@ -81,6 +84,9 @@ Creates a bottleneck ResNet block. convolution. - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. 
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -92,8 +98,8 @@ Creates a bottleneck ResNet block. """ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, + activation = relu, connection = addact$activation, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality @@ -106,10 +112,9 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina norm_layer(width)) conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) attn_layer = attn_fn(outplanes; attn_args...) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, - attn_layer, drop_path)), - activation) + return Parallel(connection, downsample, + Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, + attn_layer, drop_path)) end expansion_factor(::typeof(bottleneck)) = 4 @@ -209,6 +214,21 @@ function downsample_pool(kernel_size, inplanes, outplanes; stride = 1, dilation norm_layer(outplanes)) end +# Downsample layer which is an identity projection. Uses max pooling +# when the output size is more than the input size. +function downsample_identity(kernel_size, inplanes, outplanes; kwargs...) + if outplanes > inplanes + return Chain(MaxPool((1, 1); stride = 2), + y -> cat_channels(y, + zeros(eltype(y), + size(y, 1), + size(y, 2), + outplanes - inplanes, size(y, 4)))) + else + return identity + end +end + """ downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), stride = 1, dilation = 1, norm_layer = BatchNorm) @@ -228,24 +248,47 @@ This function is almost never used directly or customised by the user. - `dilation`: The dilation of the convolutional layer. - `norm_layer`: The normalisation layer to be used. """ -function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), - stride = 1, dilation = 1, norm_layer = BatchNorm) +function downsample_block(downsample_fns, inplanes, planes, expansion; + kernel_size = (1, 1), stride = 1, dilation = 1, + norm_layer = BatchNorm) + down_fn1, down_fn2 = downsample_fns if stride != 1 || inplanes != planes * expansion - downsample = downsample_fn(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) + downsample = down_fn2(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) else - downsample = identity + downsample = down_fn1(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) end return downsample end +const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv)) + +function _make_downsample_fns(vec::Vector{T}) where {T} + if T <: Symbol + downs = [] + for i in vec + @assert i in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + push!(downs, shortcut_dict[i]) + end + return downs + elseif T <: NTuple{2} + return vec + else + throw(ArgumentError("The shortcut list must be a `Vector` of `Symbol`s or `NTuple{2}`s")) + end +end + # Makes the main stages of the ResNet model. 
This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, - downsample_fn = downsample_conv, - drop_rates::NamedTuple, block_args::NamedTuple) + downsample_fns::Vector, drop_rates::NamedTuple, + block_args::NamedTuple) @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" expansion = expansion_factor(block_fn) stages = [] @@ -258,9 +301,10 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # DropBlock rate dbr = get(drop_rates, :drop_block_rate, 0) ## Construct each stage - for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr))) + for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr), + downsample_fns)) # Number of planes in each stage, number of blocks in each stage, and the drop block rate - planes, num_blocks, drop_block = itr + planes, num_blocks, drop_block, down_fns = itr # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -270,7 +314,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride net_stride *= stride end # Downsample block; either a (default) convolution-based block or a pooling-based block - downsample = downsample_block(downsample_fn, inplanes, planes, expansion; + downsample = downsample_block(down_fns, inplanes, planes, expansion; stride, dilation) ## Construct the blocks for each stage blocks = [] @@ -355,17 +399,19 @@ information. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. 
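
For example, with the defaults above, a call along the lines of
`resnet(bottleneck, [3, 4, 6, 3], [:B, :B, :B, :B]; inchannels = 3, nclasses = 1000)`
builds the feature extractor and classifier that `ResNet(50)` wraps (see `resnet_config` below).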
""" -function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, +function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; + inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, - downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), + block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false)) ## Feature Blocks channels = [64, 128, 256, 512] + downsample_fns = _make_downsample_fns(downsample_list) stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; - output_stride, downsample_fn, drop_rates, block_args) + output_stride, downsample_fns, drop_rates, block_args) ## Classifier head expansion = expansion_factor(block_fn) num_features = 512 * expansion @@ -384,8 +430,8 @@ function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride end # block-layer configurations for ResNet and ResNeXt models -const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), - 34 => (basicblock, [3, 4, 6, 3]), - 50 => (bottleneck, [3, 4, 6, 3]), - 101 => (bottleneck, [3, 4, 23, 3]), - 152 => (bottleneck, [3, 8, 36, 3])) +const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2], [:A, :B, :B, :B]), + 34 => (basicblock, [3, 4, 6, 3], [:A, :B, :B, :B]), + 50 => (bottleneck, [3, 4, 6, 3], [:B, :B, :B, :B]), + 101 => (bottleneck, [3, 4, 23, 3], [:B, :B, :B, :B]), + 152 => (bottleneck, [3, 8, 36, 3], [:B, :B, :B, :B])) diff --git a/src/layers/scale.jl b/src/layers/scale.jl index cd55fc97c..965b50f38 100644 --- a/src/layers/scale.jl +++ b/src/layers/scale.jl @@ -4,9 +4,9 @@ Scale the input by a scalar `λ` and applies an activation function to it. Equivalent to `activation.(λ .* x)`. """ -inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) -_input_scale(x, λ, activation) = activation.(λ .* x) -_input_scale(x, λ, ::typeof(identity)) = λ .* x +inputscale(λ; activation = identity) = _input_scale$(λ, activation) +_input_scale(λ, activation, x) = activation.(λ .* x) +_input_scale(λ, ::typeof(identity), x) = λ .* x """ LayerScale(λ, planes::Integer) diff --git a/src/utilities.jl b/src/utilities.jl index 9c29350bd..938c598f0 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -9,6 +9,28 @@ function _round_channels(channels, divisor, min_value = divisor) return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels end +""" + addact(activation = relu, xs...) + +Convenience function for applying an activation function to the output after +summing up the input arrays. Useful as the `connection` argument for the block +function in [`resnet`](#). + +See also [`reluadd`](#). +""" +addact(activation = relu, xs...) = activation(sum(tuple(xs...))) + +""" + actadd(activation = relu, xs...) + +Convenience function for adding input arrays after applying an activation +function to them. Useful as the `connection` argument for the block function in +[`resnet`](#). + +See also [`addrelu`](#). +""" +actadd(activation = relu, xs...) = sum(activation.(tuple(xs...))) + """ cat_channels(x, y, zs...) 
From 4e01443ee59f1e1f8a4df1b56dea09a23ba97a9f Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 10 Jul 2022 10:27:12 +0530 Subject: [PATCH 21/64] Generalise `resnet` further --- src/convnets/resnets/core.jl | 17 +++++++++-------- src/convnets/resnets/resnet.jl | 8 +++++++- src/convnets/resnets/seresnet.jl | 8 ++++---- src/layers/Layers.jl | 1 + test/convnets.jl | 2 +- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 295a7698f..289a7812d 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -399,7 +399,8 @@ information. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. """ -function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; +function resnet(block_fn, layers, + downsample_list::Vector = collect(:B for _ in 1:length(layers)); inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, block_args::NamedTuple = NamedTuple(), @@ -408,7 +409,7 @@ function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false)) ## Feature Blocks - channels = [64, 128, 256, 512] + channels = collect(64 * 2^i for i in range(0, length(layers))) downsample_fns = _make_downsample_fns(downsample_list) stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fns, drop_rates, block_args) @@ -429,9 +430,9 @@ function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; return Chain(Chain(stem, stage_blocks), classifier) end -# block-layer configurations for ResNet and ResNeXt models -const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2], [:A, :B, :B, :B]), - 34 => (basicblock, [3, 4, 6, 3], [:A, :B, :B, :B]), - 50 => (bottleneck, [3, 4, 6, 3], [:B, :B, :B, :B]), - 101 => (bottleneck, [3, 4, 23, 3], [:B, :B, :B, :B]), - 152 => (bottleneck, [3, 8, 36, 3], [:B, :B, :B, :B])) +# block-layer configurations for ResNet-like models +const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), + 34 => (basicblock, [3, 4, 6, 3]), + 50 => (bottleneck, [3, 4, 6, 3]), + 101 => (bottleneck, [3, 4, 23, 3]), + 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index cd26e69a3..3356ef225 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -1,3 +1,9 @@ +const resnet_shortcuts = Dict(18 => [:A, :B, :B, :B], + 34 => [:A, :B, :B, :B], + 50 => [:B, :B, :B, :B], + 101 => [:B, :B, :B, :B], + 152 => [:B, :B, :B, :B]) + """ ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @@ -27,7 +33,7 @@ end function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses) + layers = resnet(resnet_config[depth]..., resnet_shortcuts[depth]; inchannels, nclasses) if pretrain loadpretrain!(layers, string("ResNet", depth)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 58c6a0607..605c074d6 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -6,7 +6,7 @@ Creates a SEResNet model with the specified depth. 
# Arguments - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: the number of input channels. - `nclasses`: the number of output classes @@ -25,8 +25,8 @@ end (m::SEResNet)(x) = m.layers(x) function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 101, 152] - "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + @assert depth in [50, 101, 152] + "Invalid depth. Must be one of [50, 101, 152]" layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; attn_fn = squeeze_excite)) if pretrain @@ -44,7 +44,7 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. # Arguments - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 2c4b11e5a..e0b870fe9 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -8,6 +8,7 @@ using Functors using ChainRulesCore using Statistics using MLUtils +using PartialFunctions using Random include("../utilities.jl") diff --git a/test/convnets.jl b/test/convnets.jl index f0f86a2c1..6f5813d75 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -75,7 +75,7 @@ end end @testset "SEResNet" begin - @testset for depth in [18, 34, 50, 101, 152] + @testset for depth in [50, 101, 152] m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) if (SEResNet, depth) in PRETRAINED_MODELS From e8d348867a4876e07a4e94428b0e541fff23ecdd Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 10 Jul 2022 14:20:43 +0530 Subject: [PATCH 22/64] Documentation And make `downsample_opts` a smidge easier to work with. Also, a wee bit o' formatting and cleanup. --- src/convnets/resnets/core.jl | 94 ++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 289a7812d..e55eb7a76 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -130,15 +130,14 @@ on how to use this function. - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. + `:default`: Builds a stem based on the default ResNet stem, which consists of a single - 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 - max pooling layer with stride 2. - + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using a - deeper stem with 3 successive 3x3 convolutions having normalisation layers - after each one. This is followed by a 3x3 max pooling layer with stride 2. + 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling + layer with stride 2. + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using + a deeper stem with 3 successive 3x3 convolutions having normalisation layers after each + one. This is followed by a 3x3 max pooling layer with stride 2. 
+ `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second - convolution. This is an experimental variant from the `timm` library - in Python that shows peformance improvements over the `:deep` stem - in some cases. + convolution. This is an experimental variant from the `timm` library in Python that + shows peformance improvements over the `:deep` stem in some cases. - `inchannels`: The number of channels in the input. - `replace_stem_pool`: Whether to replace the default 3x3 max pooling layer with a @@ -253,20 +252,27 @@ function downsample_block(downsample_fns, inplanes, planes, expansion; norm_layer = BatchNorm) down_fn1, down_fn2 = downsample_fns if stride != 1 || inplanes != planes * expansion - downsample = down_fn2(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) + return down_fn1(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) else - downsample = down_fn1(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) + return down_fn2(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) end - return downsample end +# Shortcut configurations for the ResNet models const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv)) - -function _make_downsample_fns(vec::Vector{T}) where {T} + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) + +# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is +# specified as a `Vector` of `Symbol`s. This is used to make the downsample +# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is +# already an `NTuple{2}` of functions, it is returned unchanged. +function _make_downsample_fns(vec::Vector{T}, layers) where {T} + @assert length(vec) == length(layers) + "The length of the downsample `Vector` must match the number of stages" if T <: Symbol downs = [] for i in vec @@ -282,6 +288,13 @@ function _make_downsample_fns(vec::Vector{T}) where {T} end end +function _make_downsample_fns(sym::Symbol, layers) + @assert sym in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + return collect(shortcut_dict[sym] for _ in 1:length(layers)) +end +_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) + # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. 
A block must define a function @@ -345,13 +358,14 @@ function _drop_blocks(drop_block_prob = 0.0) end """ - resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, - stem = first(resnet_stem(; inchannels)), inplanes = 64, - downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), - drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), - use_conv = false)) + resnet(block_fn, layers, downsample_opt = :B; + inchannels = 3, nclasses = 1000, output_stride = 32, + stem = first(resnet_stem(; inchannels)), inplanes = 64, + block_args::NamedTuple = NamedTuple(), + drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = 0.0), + classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), + use_conv = false)) This function creates the layers for many ResNet-like models. See the user guide for more information. @@ -360,7 +374,7 @@ information. If you are an end-user trying to use ResNet-like models, you should consider [`ResNet`](#) and similar higher-level functions instead. This version is significantly more customisable - at the cost of being more complicated. + at the cost of being more significantly more complicated. # Arguments @@ -371,6 +385,25 @@ information. - `layers`: A list of integers specifying the number of blocks in each stage. For example, `[3, 4, 6, 3]` would mean that the network would have 4 stages, with 3, 4, 6 and 3 blocks in each. + - `downsample_opt`: Downsampling options. This can be any one of the following: + + + A single `Symbol` specifying the downsample option to use for all stages. The default + is :B, which corresponds to a 1x1 convolution-based downsample for every stage except + the first, which uses an identity projection. The other options are `:A`, which uses + an identity projection for all stages, `:C`, which uses a convolution-based + downsample for all stages and `:D`, which uses a max-pooling-based downsample for every + stage except the first, which uses an identity projection. `:A`, `:B` and `:C` are + are described in the [paper](https://arxiv.org/abs/1512.03385), while `:D` is + described in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) paper. + + A `Vector` of `Symbol`s specifying the downsample options to use for each stage. The + choices are the same as the single option above. The length of this `Vector` must be + the same as the length of `layers`. + + A `Vector` of `NTuple{2}`s specifying the downsample functions to use for each stage. + The functions have to be passed in directly here - see [`downsample_identity`](#), + [`downsample_conv`](#), and [`downsample_pool`](#). The first element of each tuple is + the downsample function to use for the first stage, and the second element is the + function to use for the rest of the stages. The length of this `Vector` must be the + same as the length of `layers`. - `nclasses`: The number of output classes. - `inchannels`: The number of input channels. - `output_stride`: The total stride of the network i.e. the amount by which the input is @@ -381,10 +414,6 @@ information. function for creating a stem (see [`resnet_stem`](#)) but you can also create your own (although this is not usually necessary). - `inplanes`: The number of output channels from the stem. - - `downsample_type`: The type of downsampling to use. Either `:conv` or `:pool`. 
The former - uses a traditional convolution-based downsampling, while the latter is an - average-pooling-based downsampling that was suggested in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) - paper. - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed to the block function. For more information regarding valid arguments, see the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). @@ -395,12 +424,13 @@ information. + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). - `classifier_args`: A `NamedTuple` that **must** specify the following arguments: - + `pool_layer`: The adaptive pooling layer to use in the classifier head. + + `pool_layer`: The pooling layer to use in the classifier head. Pass this in with the + arguments to the layer defined. For example, if you want to use an adaptive mean pooling + layer, you would pass in `AdaptiveMeanPool((1, 1))`. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. """ -function resnet(block_fn, layers, - downsample_list::Vector = collect(:B for _ in 1:length(layers)); +function resnet(block_fn, layers, downsample_opt = :B; inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, block_args::NamedTuple = NamedTuple(), @@ -410,7 +440,7 @@ function resnet(block_fn, layers, use_conv = false)) ## Feature Blocks channels = collect(64 * 2^i for i in range(0, length(layers))) - downsample_fns = _make_downsample_fns(downsample_list) + downsample_fns = _make_downsample_fns(downsample_opt, layers) stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fns, drop_rates, block_args) ## Classifier head From 92ed4fa3dee8397059696660bc3b17883955dd42 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 12 Jul 2022 11:52:46 +0530 Subject: [PATCH 23/64] Add classifier and backbone methods --- src/convnets/resnets/resnet.jl | 3 +++ src/convnets/resnets/resnext.jl | 3 +++ src/convnets/resnets/seresnet.jl | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 3356ef225..ffbee32dc 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -39,3 +39,6 @@ function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 100 end return ResNet(layers) end + +backbone(m::ResNet) = m.layers[1] +classifier(m::ResNet) = m.layers[2] diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 1fa00a7b0..dc1de9464 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -38,3 +38,6 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, end return ResNeXt(layers) end + +backbone(m::ResNeXt) = m.layers[1] +classifier(m::ResNeXt) = m.layers[2] diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 605c074d6..bc57a08fc 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -35,6 +35,9 @@ function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1 return SEResNet(layers) end +backbone(m::SEResNet) = m.layers[1] +classifier(m::SEResNet) = m.layers[2] + """ SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) @@ -75,3 +78,6 @@ function SEResNeXt(depth::Integer; pretrain = false, 
cardinality = 32, base_widt end return SEResNeXt(layers) end + +backbone(m::SEResNeXt) = m.layers[1] +classifier(m::SEResNeXt) = m.layers[2] From 96a7d31a80cb4bdb9a0d462c72b6c4735aa50947 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 17 Jul 2022 09:30:32 +0530 Subject: [PATCH 24/64] Refactor of resnet core --- src/convnets/convnext.jl | 2 +- src/convnets/efficientnet.jl | 3 +- src/convnets/resnets/core.jl | 531 ++++++++++--------------------- src/convnets/resnets/resnext.jl | 3 +- src/convnets/resnets/seresnet.jl | 7 +- src/layers/Layers.jl | 2 +- src/layers/drop.jl | 10 + src/other/mlpmixer.jl | 2 +- 8 files changed, 193 insertions(+), 367 deletions(-) diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 6ced7eeb9..113a35142 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -51,7 +51,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 push!(downsample_layers, downsample_layer) end stages = [] - dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths)) + dp_rates = droppath_rates(drop_path_rate; depth = sum(depths)) cur = 0 for i in eachindex(depths) push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]]) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index da9000468..02c5b6eb6 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -79,8 +79,7 @@ const efficientnet_block_configs = [ # w: width scaling # d: depth scaling # r: image resolution -const efficientnet_global_configs = Dict( - # (r, (w, d)) +const efficientnet_global_configs = Dict(# (r, (w, d)) :b0 => (224, (1.0, 1.0)), :b1 => (240, (1.0, 1.1)), :b2 => (260, (1.1, 1.2)), diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index e55eb7a76..4b76a3d23 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,122 +1,170 @@ ## It is recommended to check out the user guide for more information. -### ResNet blocks -## These functions return a block to be used inside of a ResNet model. -## The individual arguments are explained in the documentation of the functions. -## Note that for these blocks to be used by the `_make_blocks` function, they must define -## a dispatch `expansion(::typeof(fn))` that returns the expansion factor of the block -## (i.e. the multiplicative factor by which the number of channels in the input is increased). -## The `_make_blocks` function will then call the `expansion` function to determine the -## expansion factor of each block and use this to construct the stages of the model. 
+abstract type AbstractResNetBlock end -""" - basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, dilation = 1, first_dilation = dilation, - activation = relu, connection = addact\$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) +struct basicblock <: AbstractResNetBlock + inplanes::Integer + planes::Integer + reduction_factor::Integer +end +function basicblock(inplanes, planes, reduction_factor, base_width, cardinality) + @assert base_width == 64 "`base_width` must be 64 for `basicblock`" + @assert cardinality == 1 "`cardinality` must be 1 for `basicblock`" + return basicblock(inplanes, planes, reduction_factor) +end +expansion_factor(::basicblock) = 1 + +struct bottleneck <: AbstractResNetBlock + inplanes::Integer + planes::Integer + reduction_factor::Integer + base_width::Integer + cardinality::Integer +end +expansion_factor(::bottleneck) = 4 -Creates a basic ResNet block. +# Downsample layer using convolutions. +function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm) + return Chain(Conv((1, 1), inplanes => outplanes; stride, pad = SamePad(), bias = false), + norm_layer(outplanes)) +end -# Arguments +# Downsample layer using max pooling +function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm) + pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) + return Chain(pool, + Conv((1, 1), inplanes => outplanes; bias = false), + norm_layer(outplanes)) +end - - `inplanes`: number of input feature maps - - `planes`: number of feature maps for the block - - `stride`: the stride of the block - - `downsample`: the downsampling function to use - - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first - convolution. - - `dilation`: the dilation of the second convolution. - - `first_dilation`: the dilation of the first convolution. - - `activation`: the activation function to use. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - - `norm_layer`: the normalization layer to use. - - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. - - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the - attention function. 
-""" -function basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, dilation = 1, first_dilation = dilation, - activation = relu, connection = addact$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) - expansion = expansion_factor(basicblock) - first_planes = planes ÷ reduction_factor - outplanes = planes * expansion - conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, - dilation = first_dilation, bias = false), - norm_layer(first_planes)) - drop_block = drop_block - conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = dilation, - dilation = dilation, bias = false), - norm_layer(outplanes)) - attn_layer = attn_fn(outplanes; attn_args...) - return Parallel(connection, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, - drop_path)) +# Downsample layer which is an identity projection. Uses max pooling +# when the output size is more than the input size. +function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) + if outplanes > inplanes + return Chain(MaxPool((1, 1); stride = 2), + y -> cat_channels(y, + zeros(eltype(y), + size(y, 1), + size(y, 2), + outplanes - inplanes, size(y, 4)))) + else + return identity + end end -expansion_factor(::typeof(basicblock)) = 1 -""" - bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, connection = addact\$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) +function downsample_block(downsample_fns, inplanes, planes, expansion; stride = 1, + norm_layer = BatchNorm) + down_fn1, down_fn2 = downsample_fns + if stride != 1 || inplanes != planes * expansion + return down_fn1(inplanes, planes * expansion; stride, norm_layer) + else + return down_fn2(inplanes, planes * expansion; stride, norm_layer) + end +end + +# Shortcut configurations for the ResNet models +const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) -Creates a bottleneck ResNet block. +# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is +# specified as a `Vector` of `Symbol`s. This is used to make the downsample +# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is +# already an `NTuple{2}` of functions, it is returned unchanged. +function _make_downsample_fns(vec::Vector{<:Symbol}, layers) + downs = [] + for i in vec + @assert i in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + push!(downs, shortcut_dict[i]) + end + return downs +end +function _make_downsample_fns(sym::Symbol, layers) + @assert sym in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + return collect(shortcut_dict[sym] for _ in 1:length(layers)) +end +_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec +_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) -# Arguments +# Stride for each block in the ResNet model +function get_stride(::AbstractResNetBlock, idxs::NTuple{2, Integer}) + return (idxs[1] == 1 || idxs[1] == 1) ? 
2 : 1 +end - - `inplanes`: number of input feature maps - - `planes`: number of feature maps for the block - - `stride`: the stride of the block - - `downsample`: the downsampling function to use - - `cardinality`: the number of groups in the convolution. - - `base_width`: the number of output feature maps for each convolutional group. - - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first - convolution. - - `first_dilation`: the dilation of the 3x3 convolution. - - `activation`: the activation function to use. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - - `norm_layer`: the normalization layer to use. - - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. - - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the - attention function. -""" -function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, connection = addact$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) - expansion = expansion_factor(bottleneck) - width = floor(Int, planes * (base_width / 64)) * cardinality - first_planes = width ÷ reduction_factor - outplanes = planes * expansion - conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), - norm_layer(first_planes, activation)) - conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, - dilation = first_dilation, groups = cardinality, bias = false), +# returns `DropBlock`s for each stage of the ResNet +function _drop_blocks(drop_block_rate::AbstractFloat) + return [ + identity, identity, + DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00) + ] +end + +function _make_layers(block::basicblock, norm_layer, stride) + first_planes = block.planes ÷ block.reduction_factor + outplanes = block.planes * expansion_factor(block) + conv_bn1 = Chain(Conv((3, 3), block.inplanes => first_planes; stride, pad = 1, bias = false), + norm_layer(first_planes)) + conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = 1, bias = false), + norm_layer(outplanes)) + layers = [] + push!(layers, conv_bn1, conv_bn2) + return layers +end + +function _make_layers(block::bottleneck, norm_layer, stride) + width = fld(block.planes * block.base_width, 64) * block.cardinality + first_planes = width ÷ block.reduction_factor + outplanes = block.planes * expansion_factor(block) + conv_bn1 = Chain(Conv((1, 1), block.inplanes => first_planes; bias = false), + norm_layer(first_planes)) + conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = 1, + groups = block.cardinality, bias = false), norm_layer(width)) conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) - attn_layer = attn_fn(outplanes; attn_args...) 
- return Parallel(connection, downsample, - Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, - attn_layer, drop_path)) + layers = [] + push!(layers, conv_bn1, conv_bn2, conv_bn3) + return layers +end + +function make_block(block::T, idxs::NTuple{2, Integer}; kwargs...) where {T <: AbstractResNetBlock} + stage_idx, block_idx = idxs + kwargs = Dict(kwargs) + stride = get(kwargs, :stride_fn, get_stride)(block, idxs) + expansion = expansion_factor(block) + norm_layer = get(kwargs, :norm_layer, BatchNorm) + layers = _make_layers(block, norm_layer, stride) + activation = get(kwargs, :activation, relu) + insert!(layers, 2, activation) + if T <: bottleneck + insert!(layers, 4, activation) + end + if haskey(kwargs, :drop_block_rate) + layer_idx = T <: basicblock ? 2 : 3 + dropblock = _drop_blocks(kwargs[:drop_block_rate])[stage_idx] + insert!(layers, layer_idx, dropblock) + end + if haskey(kwargs, :attn_fn) + attn_layer = kwargs[:attn_fn](block.planes) + push!(layers, attn_layer) + end + if haskey(kwargs, :drop_path_rate) + droppath = DropPath(kwargs[:droppath_rates][block_idx]) + push!(layers, droppath) + end + if haskey(kwargs, :downsample_fns) + downsample_tup = kwargs[:downsample_fns][stage_idx] + downsample = downsample_block(downsample_tup, block.inplanes, block.planes, expansion; stride) + connection = get(kwargs, :connection, addact)$activation + return Parallel(connection, downsample, Chain(layers...)) + else + return Chain(layers...) + end end -expansion_factor(::typeof(bottleneck)) = 4 """ resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, @@ -145,8 +193,8 @@ on how to use this function. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. """ -function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, - norm_layer = BatchNorm, activation = relu) +function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, + replace_stem_pool::Bool = false, norm_layer = BatchNorm, activation = relu) @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem @@ -181,272 +229,43 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = return Chain(conv1, bn1, stempool), inplanes end -### Downsampling layers -## These will almost never be used directly. They are used by the `_make_blocks` function to -## build the downsampling layers. In most cases, these defaults will not need to be changed. -## If you wish to write your own ResNet model using the `_make_blocks` function, you can use -## this function to build the downsampling layers. - -# Downsample layer using convolutions. -function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size - dilation = kernel_size[1] > 1 ? dilation : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, - dilation, bias = false), - norm_layer(outplanes)) -end - -# Downsample layer using max pooling -function downsample_pool(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? 
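For intuition about how `make_block` earlier in this hunk stitches the pieces together, here is a minimal sketch of the same pattern: build the layer list, splice in the activation, then wrap the residual branch and the shortcut in a `Parallel`. The sizes, the identity shortcut, and the hand-written add-then-activate closure standing in for `addact$activation` are all illustrative assumptions.

```julia
using Flux

inplanes, planes = 64, 64
act = x -> relu.(x)                               # broadcasting activation wrapper
layers = Any[Chain(Conv((3, 3), inplanes => planes; pad = 1, bias = false), BatchNorm(planes)),
             Chain(Conv((3, 3), planes => planes; pad = 1, bias = false), BatchNorm(planes))]
insert!(layers, 2, act)                           # activation between the two conv-bn pairs

connection = (x, y) -> relu.(x .+ y)              # add the branches, then activate
downsample = identity                             # identity shortcut: shapes already match
block = Parallel(connection, downsample, Chain(layers...))

size(block(rand(Float32, 32, 32, inplanes, 1)))   # (32, 32, 64, 1)
```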
SamePad() : 0 - pool = MeanPool((2, 2); stride = avg_stride, pad) - end - return Chain(pool, - Conv((1, 1), inplanes => outplanes; bias = false), - norm_layer(outplanes)) -end - -# Downsample layer which is an identity projection. Uses max pooling -# when the output size is more than the input size. -function downsample_identity(kernel_size, inplanes, outplanes; kwargs...) - if outplanes > inplanes - return Chain(MaxPool((1, 1); stride = 2), - y -> cat_channels(y, - zeros(eltype(y), - size(y, 1), - size(y, 2), - outplanes - inplanes, size(y, 4)))) - else - return identity - end -end - -""" - downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), - stride = 1, dilation = 1, norm_layer = BatchNorm) - -Wrapper function that makes it easier to build a downsample block inside a ResNet model. -This function is almost never used directly or customised by the user. - -# Arguments - - - `downsample_fn`: The function to use for downsampling in skip connections. Recommended usage - is passing in either `downsample_conv` or `downsample_pool`. - - `inplanes`: The number of input feature maps. - - `planes`: The number of output feature maps. - - `expansion`: The expansion factor of the block. - - `kernel_size`: The size of the convolutional kernel. - - `stride`: The stride of the convolutional layer. - - `dilation`: The dilation of the convolutional layer. - - `norm_layer`: The normalisation layer to be used. -""" -function downsample_block(downsample_fns, inplanes, planes, expansion; - kernel_size = (1, 1), stride = 1, dilation = 1, - norm_layer = BatchNorm) - down_fn1, down_fn2 = downsample_fns - if stride != 1 || inplanes != planes * expansion - return down_fn1(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) - else - return down_fn2(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) - end -end - -# Shortcut configurations for the ResNet models -const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), - :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv), - :D => (downsample_pool, downsample_identity)) - -# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is -# specified as a `Vector` of `Symbol`s. This is used to make the downsample -# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is -# already an `NTuple{2}` of functions, it is returned unchanged. -function _make_downsample_fns(vec::Vector{T}, layers) where {T} - @assert length(vec) == length(layers) - "The length of the downsample `Vector` must match the number of stages" - if T <: Symbol - downs = [] - for i in vec - @assert i in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - push!(downs, shortcut_dict[i]) - end - return downs - elseif T <: NTuple{2} - return vec - else - throw(ArgumentError("The shortcut list must be a `Vector` of `Symbol`s or `NTuple{2}`s")) - end -end - -function _make_downsample_fns(sym::Symbol, layers) - @assert sym in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - return collect(shortcut_dict[sym] for _ in 1:length(layers)) -end -_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) - # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. 
A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, - downsample_fns::Vector, drop_rates::NamedTuple, - block_args::NamedTuple) - @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" - expansion = expansion_factor(block_fn) +function resnet_stages(block_type, channels, block_repeats, inplanes; kwargs...) stages = [] - net_block_idx = 1 - net_stride = 4 - dilation = prev_dilation = 1 - # Stochastic depth linear decay rule (DropPath) - dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), - sum(block_repeats)) - # DropBlock rate - dbr = get(drop_rates, :drop_block_rate, 0) + kwargs = Dict(kwargs) + cardinality = get(kwargs, :cardinality, 1) + base_width = get(kwargs, :base_width, 64) + reduction_factor = get(kwargs, :reduction_factor, 1) ## Construct each stage - for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr), - downsample_fns)) - # Number of planes in each stage, number of blocks in each stage, and the drop block rate - planes, num_blocks, drop_block, down_fns = itr - # Stride calculations for each stage - stride = stage_idx == 1 ? 1 : 2 - if net_stride >= output_stride - dilation *= stride - stride = 1 - else - net_stride *= stride - end - # Downsample block; either a (default) convolution-based block or a pooling-based block - downsample = downsample_block(down_fns, inplanes, planes, expansion; - stride, dilation) + for (stage_idx, (planes, num_blocks)) in enumerate(zip(channels, block_repeats)) ## Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks - # Different behaviour for the first block of each stage - downsample = block_idx == 1 ? downsample : identity - stride = block_idx == 1 ? stride : 1 - push!(blocks, - block_fn(inplanes, planes; stride, downsample, - first_dilation = prev_dilation, - drop_path = DropPath(dp_rates[block_idx]), drop_block, - block_args...)) - prev_dilation = dilation - inplanes = planes * expansion - net_block_idx += 1 + block_struct = block_type(inplanes, planes, reduction_factor, base_width, cardinality) + block = make_block(block_struct, (stage_idx, block_idx); kwargs...) + inplanes = planes * expansion_factor(block_struct) + push!(blocks, block) end push!(stages, Chain(blocks...)) end - return Chain(stages...) -end - -# returns `DropBlock`s for each stage of the ResNet -function _drop_blocks(drop_block_prob = 0.0) - return [ - identity, identity, - DropBlock(drop_block_prob, 5, 0.25), DropBlock(drop_block_prob, 3, 1.00), - ] + return Chain(stages...), inplanes end -""" - resnet(block_fn, layers, downsample_opt = :B; - inchannels = 3, nclasses = 1000, output_stride = 32, - stem = first(resnet_stem(; inchannels)), inplanes = 64, - block_args::NamedTuple = NamedTuple(), - drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), - use_conv = false)) - -This function creates the layers for many ResNet-like models. See the user guide for more -information. - -!!! note - - If you are an end-user trying to use ResNet-like models, you should consider [`ResNet`](#) - and similar higher-level functions instead. This version is significantly more customisable - at the cost of being more significantly more complicated. - -# Arguments - - - `block_fn`: The type of block to use inside the ResNet model. 
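As a quick illustration of the bookkeeping `resnet_stages` performs above, the sketch below (a standalone approximation, assuming a bottleneck block with expansion factor 4) traces how the per-stage `planes` and the running `inplanes` evolve for a ResNet-50-style `[3, 4, 6, 3]` layout.

```julia
# Standalone approximation of the stage bookkeeping in `resnet_stages` above.
function stage_plan(block_repeats; expansion = 4, inplanes = 64)
    channels = [64 * 2^i for i in 0:(length(block_repeats) - 1)]   # [64, 128, 256, 512]
    plan = NamedTuple[]
    for (planes, nblocks) in zip(channels, block_repeats)
        push!(plan, (planes = planes, nblocks = nblocks,
                     inplanes = inplanes, outplanes = planes * expansion))
        inplanes = planes * expansion    # the next stage sees the expanded width
    end
    return plan
end

stage_plan([3, 4, 6, 3])
# (planes = 64, …, inplanes = 64, outplanes = 256), (planes = 128, …, inplanes = 256, outplanes = 512), …
```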
Must be either `:basicblock`, - which is the standard ResNet block, or `:bottleneck`, which is the ResNet block with a - bottleneck structure. See the [paper](https://arxiv.org/abs/1512.03385) for more details. - - - `layers`: A list of integers specifying the number of blocks in each stage. For example, - `[3, 4, 6, 3]` would mean that the network would have 4 stages, with 3, 4, 6 and 3 blocks in - each. - - `downsample_opt`: Downsampling options. This can be any one of the following: - - + A single `Symbol` specifying the downsample option to use for all stages. The default - is :B, which corresponds to a 1x1 convolution-based downsample for every stage except - the first, which uses an identity projection. The other options are `:A`, which uses - an identity projection for all stages, `:C`, which uses a convolution-based - downsample for all stages and `:D`, which uses a max-pooling-based downsample for every - stage except the first, which uses an identity projection. `:A`, `:B` and `:C` are - are described in the [paper](https://arxiv.org/abs/1512.03385), while `:D` is - described in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) paper. - + A `Vector` of `Symbol`s specifying the downsample options to use for each stage. The - choices are the same as the single option above. The length of this `Vector` must be - the same as the length of `layers`. - + A `Vector` of `NTuple{2}`s specifying the downsample functions to use for each stage. - The functions have to be passed in directly here - see [`downsample_identity`](#), - [`downsample_conv`](#), and [`downsample_pool`](#). The first element of each tuple is - the downsample function to use for the first stage, and the second element is the - function to use for the rest of the stages. The length of this `Vector` must be the - same as the length of `layers`. - - `nclasses`: The number of output classes. - - `inchannels`: The number of input channels. - - `output_stride`: The total stride of the network i.e. the amount by which the input is - downsampled throughout the network. This is used to determine the output size from the - backbone of the network. Must be one of `[8, 16, 32]`. - - `stem`: A constructed ResNet stem, passed in to be used in the model. `inplanes` should be - set to the number of output channels from this stem. Metalhead provides an in-built - function for creating a stem (see [`resnet_stem`](#)) but you can also create your - own (although this is not usually necessary). - - `inplanes`: The number of output channels from the stem. - - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed - to the block function. For more information regarding valid arguments, see - the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). - - `drop_rates`: A `NamedTuple` that may define none, some or all of the following: - - + `dropout_rate`: The rate of dropout to be used in the classifier head. - + `drop_path_rate`: Stochastic depth implemented using [`DropPath`](#). - + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). - - `classifier_args`: A `NamedTuple` that **must** specify the following arguments: - - + `pool_layer`: The pooling layer to use in the classifier head. Pass this in with the - arguments to the layer defined. For example, if you want to use an adaptive mean pooling - layer, you would pass in `AdaptiveMeanPool((1, 1))`. 
- + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a - `Dense` layer. -""" -function resnet(block_fn, layers, downsample_opt = :B; - inchannels = 3, nclasses = 1000, output_stride = 32, - stem = first(resnet_stem(; inchannels)), inplanes = 64, - block_args::NamedTuple = NamedTuple(), - drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), - use_conv = false)) +function resnet(block_fn, layers, downsample_opt = :B; inchannels::Integer = 3, + nclasses::Integer = 1000, stem = first(resnet_stem(; inchannels)), + inplanes::Integer = 64, kwargs...) + kwargs = Dict(kwargs) ## Feature Blocks channels = collect(64 * 2^i for i in range(0, length(layers))) downsample_fns = _make_downsample_fns(downsample_opt, layers) - stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; - output_stride, downsample_fns, drop_rates, block_args) + stage_blocks, num_features = resnet_stages(block_fn, channels, layers, inplanes; downsample_fns, kwargs...) ## Classifier head - expansion = expansion_factor(block_fn) - num_features = 512 * expansion - pool_layer, use_conv = classifier_args + # num_features = 512 * expansion_factor(block_fn) + pool_layer = get(kwargs, :pool_layer, AdaptiveMeanPool((1, 1))) + use_conv = get(kwargs, :use_conv, false) # Pooling if pool_layer === identity @assert use_conv @@ -456,7 +275,7 @@ function resnet(block_fn, layers, downsample_opt = :B; global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer # Fully-connected layer fc = create_fc(num_features, nclasses; use_conv) - classifier = Chain(global_pool, Dropout(get(drop_rates, :dropout_rate, 0)), fc) + classifier = Chain(global_pool, Dropout(get(kwargs, :dropout_rate, 0)), fc) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index dc1de9464..cee2e4757 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -31,8 +31,7 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width)) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index bc57a08fc..1de2f2195 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -27,8 +27,7 @@ end function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; attn_fn = squeeze_excite)) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, attn_fn = _ -> squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNet", depth)) end @@ -71,8 +70,8 @@ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_widt inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. 
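The classifier head assembled at the end of `resnet` above can be reproduced in isolation; this sketch assumes the default `AdaptiveMeanPool((1, 1))` pooling, `use_conv = false`, and a plain `Dense`/`Conv` in place of the internal `create_fc` helper.

```julia
using Flux, MLUtils

num_features, nclasses, dropout_rate = 2048, 1000, 0.0
pool_layer = AdaptiveMeanPool((1, 1))
use_conv = false

# With a Dense head, the pooled feature map must be flattened before the final layer.
fc = use_conv ? Conv((1, 1), num_features => nclasses) : Dense(num_features, nclasses)
global_pool = use_conv ? pool_layer : Chain(pool_layer, MLUtils.flatten)
classifier = Chain(global_pool, Dropout(dropout_rate), fc)

size(classifier(rand(Float32, 7, 7, num_features, 1)))  # (1000, 1)
```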
Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width, + attn_fn = _ -> squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index e0b870fe9..4269411ef 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -29,7 +29,7 @@ include("conv.jl") export conv_bn, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection include("drop.jl") -export DropPath, DropBlock +export DropBlock, DropPath, droppath_rates include("selayers.jl") export squeeze_excite, effective_squeeze_excite diff --git a/src/layers/drop.jl b/src/layers/drop.jl index dc6cb3c54..981a05efc 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -146,3 +146,13 @@ equivalent to `identity`. on the CPU. """ DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity + +""" + droppath_rates(drop_path_rate::AbstractFloat = 0.0; depth) + +Returns the drop path rates for a given depth using the linear scaling rule +((reference)[https://arxiv.org/abs/1603.09382]) +""" +function droppath_rates(drop_path_rate::AbstractFloat = 0.0; depth) + return LinRange{Float32}(0.0, drop_path_rate, depth) +end \ No newline at end of file diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 48f1efd8c..ec74c79d6 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -59,7 +59,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, embedplanes = 512, drop_path_rate = 0.0, depth = 12, nclasses = 1000, kwargs...) npatches = prod(imsize .÷ patch_size) - dp_rates = LinRange{Float32}(0.0, drop_path_rate, depth) + dp_rates = droppath_rates(drop_path_rate; depth) layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], kwargs...) 
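A quick look at what the `droppath_rates` helper above produces, and how each rate could feed a `DropPath` layer. Both definitions here are simplified local copies for illustration only (the in-tree `DropPath` also accepts an `rng` keyword).

```julia
using Flux

# Simplified local copies for illustration.
droppath_rates(drop_path_rate = 0.0; depth) = LinRange{Float32}(0.0, drop_path_rate, depth)
DropPath(p) = 0 < p ≤ 1 ? Dropout(p; dims = 4) : identity

rates = droppath_rates(0.1; depth = 12)   # 0.0, 0.009…, …, 0.1 – one rate per block
paths = [DropPath(p) for p in rates]      # first entry is `identity`, the rest are Dropout layers
```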
From 954029910ac38497e85c70b782c129cbb066f7f3 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 16 Jun 2022 18:16:13 +0530 Subject: [PATCH 25/64] Add `DropBlock` --- Project.toml | 2 + src/convnets/densenet.jl | 2 +- src/layers/Layers.jl | 8 ++-- src/layers/drop.jl | 61 ++++++++++++++++++++++++++++ src/layers/{mlp.jl => mlp-linear.jl} | 15 +++++++ src/layers/others.jl | 26 ------------ 6 files changed, 84 insertions(+), 30 deletions(-) create mode 100644 src/layers/drop.jl rename src/layers/{mlp.jl => mlp-linear.jl} (83%) delete mode 100644 src/layers/others.jl diff --git a/Project.toml b/Project.toml index b585973c9..42f12a887 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,14 @@ version = "0.7.3" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 0b318dbf3..588d2ad22 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -100,7 +100,7 @@ Create a DenseNet model - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ -function densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000) +function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, nclasses = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; reduction = reduction, nclasses = nclasses) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 1034136f3..8b2b73059 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,8 +1,10 @@ module Layers using Flux -using Flux: outputsize, Zygote +using NNlib +using NNlibCUDA using Functors +using ChainRulesCore using Statistics using MLUtils @@ -10,10 +12,10 @@ include("../utilities.jl") include("attention.jl") include("embeddings.jl") -include("mlp.jl") +include("mlp-linear.jl") include("normalise.jl") include("conv.jl") -include("others.jl") +include("drop.jl") export MHAttention, PatchEmbedding, ViPosEmbedding, ClassTokens, diff --git a/src/layers/drop.jl b/src/layers/drop.jl new file mode 100644 index 000000000..93c120651 --- /dev/null +++ b/src/layers/drop.jl @@ -0,0 +1,61 @@ +""" + DropBlock(drop_prob = 0.1, block_size = 7) + +Implements DropBlock, a regularization method for convolutional networks. 
+([reference](https://arxiv.org/pdf/1810.12890.pdf)) +""" +struct DropBlock{F} + drop_prob::F + block_size::Integer +end +@functor DropBlock + +(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size) + +DropBlock(drop_prob = 0.1, block_size = 7) = DropBlock(drop_prob, block_size) + +function _dropblock_checks(x, drop_prob, T) + if !(T <: AbstractArray) + throw(ArgumentError("x must be an `AbstractArray`")) + end + if ndims(x) != 4 + throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) + end + @assert drop_prob < 0 || drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" +end +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, T) + +function dropblock(x::T, drop_prob, block_size::Integer) where {T} + _dropblock_checks(x, drop_prob, T) + if drop_prob == 0 + return x + end + return _dropblock(x, drop_prob, block_size) +end + +function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size) where {T} + gamma = drop_prob / (block_size ^ 2) + mask = rand_like(x, Float32, (size(x, 1), size(x, 2), size(x, 3))) + mask .<= gamma + block_mask = maxpool(reshape(mask, (size(mask)[1:3]..., 1)), (block_size, block_size); + pad = block_size ÷ 2, stride = (1, 1)) + if block_size % 2 == 0 + block_mask = block_mask[1:(end - 1), 1:(end - 1), :, :] + end + block_mask = 1 .- dropdims(block_mask; dims = 4) + out = (x .* reshape(block_mask, (size(block_mask)[1:3]..., 1))) * length(block_mask) / + sum(block_mask) + return out +end + +""" + DropPath(p) + +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. +([reference](https://arxiv.org/abs/1603.09382)) + +# Arguments + + - `p`: rate of Stochastic Depth. +""" +DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity diff --git a/src/layers/mlp.jl b/src/layers/mlp-linear.jl similarity index 83% rename from src/layers/mlp.jl rename to src/layers/mlp-linear.jl index 25ead874b..e282e2632 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp-linear.jl @@ -1,3 +1,18 @@ +""" + LayerScale(λ, planes::Integer) + +Creates a `Flux.Scale` layer that performs "`LayerScale`" +([reference](https://arxiv.org/abs/2103.17239)). + +# Arguments + + - `planes`: Size of channel dimension in the input. + - `λ`: initialisation value for the learnable diagonal matrix. +""" +function LayerScale(planes::Integer, λ) + return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity +end + """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; dropout = 0., activation = gelu) diff --git a/src/layers/others.jl b/src/layers/others.jl deleted file mode 100644 index 770bccebd..000000000 --- a/src/layers/others.jl +++ /dev/null @@ -1,26 +0,0 @@ -""" - LayerScale(λ, planes::Integer) - -Creates a `Flux.Scale` layer that performs "`LayerScale`" -([reference](https://arxiv.org/abs/2103.17239)). - -# Arguments - - - `planes`: Size of channel dimension in the input. - - `λ`: initialisation value for the learnable diagonal matrix. -""" -function LayerScale(planes::Integer, λ) - return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity -end - -""" - DropPath(p) - -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. -([reference](https://arxiv.org/abs/1603.09382)) - -# Arguments - - - `p`: rate of Stochastic Depth. -""" -DropPath(p) = p ≥ 0 ? 
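For readers new to DropBlock, the masking step in `_dropblock` above can be reproduced with a few lines of NNlib. This standalone, CPU-only sketch (using a 4-D mask rather than the 3-D mask in the patch) shows how seed positions are grown into square blocks by a max-pool and how the surviving activations are rescaled.

```julia
using NNlib: maxpool

# Illustrative sketch of the DropBlock masking step, not the package implementation.
function dropblock_sketch(x::AbstractArray{Float32, 4}, drop_prob, block_size::Integer)
    gamma = drop_prob / block_size^2                      # probability of seeding a block
    seeds = Float32.(rand(Float32, size(x)...) .< gamma)  # 1 at block centres, 0 elsewhere
    # Grow every seed into a `block_size × block_size` square via max-pooling.
    blocks = maxpool(seeds, (block_size, block_size); pad = block_size ÷ 2, stride = 1)
    keep = 1 .- blocks                                    # 1 where activations survive
    return x .* keep .* Float32(length(keep) / sum(keep)) # rescale to preserve magnitude
end

x = rand(Float32, 16, 16, 8, 2)
size(dropblock_sketch(x, 0.1, 7)) == size(x)              # true
```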
Dropout(p; dims = 4) : identity From 588d70325e7142348ba95a20735a6725f11f6da7 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 21 Jun 2022 15:21:12 +0530 Subject: [PATCH 26/64] Initial commit for new ResNet API --- docs/make.jl | 4 +- docs/serve.jl | 2 +- src/Metalhead.jl | 4 +- src/convnets/densenet.jl | 3 +- src/convnets/resnet.jl | 408 ++++++++++++++++----------------------- src/layers/Layers.jl | 3 +- src/layers/drop.jl | 45 +++-- src/layers/normalise.jl | 4 +- 8 files changed, 202 insertions(+), 271 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index db03f1d76..f5d29f7e9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Pkg -Pkg.develop(path = "..") +Pkg.develop(; path = "..") using Publish using Artifacts, LazyArtifacts @@ -13,5 +13,5 @@ p = Publish.Project(Metalhead) function build_and_deploy(label) rm(label; recursive = true, force = true) - deploy(Metalhead; root = "/Metalhead.jl", label = label) + return deploy(Metalhead; root = "/Metalhead.jl", label = label) end diff --git a/docs/serve.jl b/docs/serve.jl index 763e77e93..bf4a51179 100644 --- a/docs/serve.jl +++ b/docs/serve.jl @@ -1,6 +1,6 @@ using Pkg -Pkg.develop(path = "..") +Pkg.develop(; path = "..") using Revise using Publish diff --git a/src/Metalhead.jl b/src/Metalhead.jl index f391c0c66..9a60ad351 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -40,7 +40,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, +# ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -49,7 +49,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, +for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, # :ResNet, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 588d2ad22..374909bb1 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -100,7 +100,8 @@ Create a DenseNet model - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ -function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, nclasses = 1000) +function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, + nclasses = 1000) where {N} return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; reduction = reduction, nclasses = nclasses) end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl index 53d1fd6e3..768697131 100644 --- a/src/convnets/resnet.jl +++ b/src/convnets/resnet.jl @@ -1,259 +1,185 @@ -""" - basicblock(inplanes, outplanes, downsample = false) - -Create a basic residual block -([reference](https://arxiv.org/abs/1512.03385v1)). 
- -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input -""" -function basicblock(inplanes, outplanes, downsample = false) - stride = downsample ? 2 : 1 - return Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, - bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, - bias = false)...) +function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, + reduce_first = 1, dilation = 1, first_dilation = nothing, + act_layer = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = 1 + @assert cardinality==1 "BasicBlock only supports cardinality of 1" + @assert base_width==64 "BasicBlock does not support changing base width" + first_planes = planes ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? first_dilation : dilation + conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, + dilation = first_dilation, bias = false), + norm_layer(first_planes)) + drop_block = drop_block === identity ? identity : drop_block + conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; stride, pad = dilation, + dilation = dilation, bias = false), + norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_path)), + act_layer) end -""" - bottleneck(inplanes, outplanes, downsample = false; stride = [1, (downsample ? 2 : 1), 1]) - -Create a bottleneck residual block -([reference](https://arxiv.org/abs/1512.03385v1)). The bottleneck is composed of -3 convolutional layers each with the given `stride`. -By default, `stride` implements ["ResNet v1.5"](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch) -which uses `stride == [1, 2, 1]` when `downsample == true`. -This version is standard across various ML frameworks. -The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead. - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input - - `stride`: a list of the stride of the 3 convolutional layers -""" -function bottleneck(inplanes, outplanes, downsample = false; - stride = [1, (downsample ? 2 : 1), 1]) - return Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], - bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, - bias = false)..., - conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], - bias = false)...) +function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, + reduce_first = 1, dilation = 1, first_dilation = nothing, + act_layer = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = 4 + width = floor(Int, planes * (base_width / 64)) * cardinality + first_planes = width ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? 
first_dilation : dilation + conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), + norm_layer(first_planes)) + conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, + dilation = first_dilation, groups = cardinality, bias = false), + norm_layer(width)) + drop_block = drop_block === identity ? identity : drop_block() + conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_block, + act_layer, conv_bn3, drop_path)), + act_layer) end -""" - bottleneck_v1(inplanes, outplanes, downsample = false) - -Create a bottleneck residual block -([reference](https://arxiv.org/abs/1512.03385v1)). The bottleneck is composed of -3 convolutional layers with all a stride of 1 except the first convolutional -layer which has a stride of 2. - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input -""" -function bottleneck_v1(inplanes, outplanes, downsample = false) - return bottleneck(inplanes, outplanes, downsample; - stride = [(downsample ? 2 : 1), 1, 1]) +function drop_blocks(drop_prob = 0.0) + return [identity, identity, + drop_prob == 0.0 ? DropBlock(drop_prob, 5, 0.25) : identity, + drop_prob == 0.0 ? DropBlock(drop_prob, 3, 1.00) : identity] end -""" - resnet(block, residuals::NTuple{2, Any}, connection = addrelu; - channel_config, block_config, nclasses = 1000) - -Create a ResNet model -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - `residuals`: a 2-tuple of functions with input `(inplanes, outplanes, downsample=false)`, - each of which will return a function that will be used as a new "skip" path to match a residual block. - [`Metalhead.skip_identity`](#) and [`Metalhead.skip_projection`](#) can be used here. 
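The padding rule in `downsample_conv` above is compact but easy to misread; evaluating it for a few typical cases (values purely illustrative) shows it reduces to the familiar "same"-style padding:

```julia
# pad = ((stride - 1) + dilation * (k - 1)) ÷ 2, as used by `downsample_conv` above.
shortcut_pad(k, stride, dilation) = ((stride - 1) + dilation * (k - 1)) ÷ 2

shortcut_pad(1, 1, 1)   # 0 – plain 1×1 projection
shortcut_pad(1, 2, 1)   # 0 – strided 1×1 projection (the common ResNet shortcut)
shortcut_pad(3, 2, 1)   # 1 – strided 3×3 projection
shortcut_pad(3, 1, 2)   # 2 – dilated 3×3 projection
```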
- - `connection`: the binary function applied to the output of residual and skip paths in a block - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection = addrelu; - channel_config, block_config, nclasses = 1000) - inplanes = 64 - baseplanes = 64 - layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false)) - push!(layers, MaxPool((3, 3); stride = (2, 2), pad = (1, 1))) - for (i, nrepeats) in enumerate(block_config) - # output planes within a block - outplanes = baseplanes .* channel_config - # push first skip connection on using first residual - # downsample the residual path if this is the first repetition of a block - push!(layers, - Parallel(connection, block(inplanes, outplanes, i != 1), - residuals[i][1](inplanes, outplanes[end], i != 1))) - # push remaining skip connections on using second residual - inplanes = outplanes[end] - for _ in 2:nrepeats - push!(layers, - Parallel(connection, block(inplanes, outplanes, false), - residuals[i][2](inplanes, outplanes[end], false))) - inplanes = outplanes[end] - end - # next set of output plane base is doubled - baseplanes *= 2 - end - # next set of output plane base is doubled - baseplanes *= 2 - return Chain(Chain(layers), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(inplanes, nclasses))) +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) end -""" - resnet(block, shortcut_config::Symbol, connection = addrelu; - channel_config, block_config, nclasses = 1000) - -Create a ResNet model -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - - `shortcut_config`: the type of shortcut style (either `:A`, `:B`, or `:C`) - - + `:A`: uses a [`Metalhead.skip_identity`](#) for all residual blocks - + `:B`: uses a [`Metalhead.skip_projection`](#) for the first residual block - and [`Metalhead.skip_identity`](@) for the remaining residual blocks - + `:C`: uses a [`Metalhead.skip_projection`](#) for all residual blocks - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnet(block, shortcut_config::AbstractVector{<:Symbol}, args...; kwargs...) 
- shortcut_dict = Dict(:A => (skip_identity, skip_identity), - :B => (skip_projection, skip_identity), - :C => (skip_projection, skip_projection)) - if any(sc -> !haskey(shortcut_dict, sc), shortcut_config) - error("Unrecognized shortcut_config ($shortcut_config) passed to `resnet` (use only :A, :B, or :C).") +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 + pool = avg_pool_fn((2, 2); stride = avg_stride, pad) end - shortcut = [shortcut_dict[sc] for sc in shortcut_config] - return resnet(block, shortcut, args...; kwargs...) -end - -function resnet(block, shortcut_config::Symbol, args...; block_config, kwargs...) - return resnet(block, fill(shortcut_config, length(block_config)), args...; - block_config = block_config, kwargs...) -end - -function resnet(block, residuals::NTuple{2}, args...; kwargs...) - return resnet(block, [residuals], args...; kwargs...) -end - -const resnet_config = Dict(18 => (([1, 1], [2, 2, 2, 2], [:A, :B, :B, :B]), basicblock), - 34 => (([1, 1], [3, 4, 6, 3], [:A, :B, :B, :B]), basicblock), - 50 => (([1, 1, 4], [3, 4, 6, 3], [:B, :B, :B, :B]), bottleneck), - 101 => (([1, 1, 4], [3, 4, 23, 3], [:B, :B, :B, :B]), bottleneck), - 152 => (([1, 1, 4], [3, 8, 36, 3], [:B, :B, :B, :B]), bottleneck)) - -""" - ResNet(channel_config, block_config, shortcut_config; - block, connection = addrelu, nclasses = 1000) -Create a `ResNet` model -([reference](https://arxiv.org/abs/1512.03385v1)). -See also [`resnet`](#). - -# Arguments - - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `shortcut_config`: the type of shortcut style (either `:A`, `:B`, or `:C`). - `shortcut_config` can also be a vector of symbols if different shortcut styles are applied to - different residual blocks. - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `nclasses`: the number of output classes -""" -struct ResNet - layers::Any + return Chain(pool, + Conv((1, 1), in_channels => out_channels; stride = 1, pad = 0, + bias = false), + norm_layer(out_channels)) end -function ResNet(channel_config, block_config, shortcut_config; - block, connection = addrelu, nclasses = 1000) - layers = resnet(block, - shortcut_config, - connection; - channel_config = channel_config, - block_config = block_config, - nclasses = nclasses) - return ResNet(layers) +function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, + reduce_first = 1, output_stride = 32, + down_kernel_size = 1, avg_down = false, drop_block_rate = 0.0, + drop_path_rate = 0.0, kwargs...) + kwarg_dict = Dict(kwargs...) + stages = [] + net_block_idx = 1 + net_stride = 4 + dilation = prev_dilation = 1 + for (stage_idx, (planes, num_blocks, db)) in enumerate(zip(channels, block_repeats, + drop_blocks(drop_block_rate))) + stride = stage_idx == 1 ? 1 : 2 + if net_stride >= output_stride + dilation *= stride + stride = 1 + else + net_stride *= stride + end + downsample = identity + if stride != 1 || inplanes != planes * expansion + downsample = avg_down ? 
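The `output_stride` bookkeeping in `make_blocks` above trades stride for dilation once the network has downsampled far enough. A standalone sketch of just that loop (stage strides follow `stage_idx == 1 ? 1 : 2`; everything else is illustrative):

```julia
# How `output_stride` converts later stage strides into dilations.
function stride_plan(nstages; output_stride = 32)
    net_stride, dilation = 4, 1          # the stem has already downsampled by 4
    plan = NamedTuple[]
    for stage_idx in 1:nstages
        stride = stage_idx == 1 ? 1 : 2
        if net_stride >= output_stride
            dilation *= stride
            stride = 1
        else
            net_stride *= stride
        end
        push!(plan, (stage = stage_idx, stride = stride, dilation = dilation))
    end
    return plan
end

stride_plan(4; output_stride = 32)   # strides 1, 2, 2, 2 – no dilation
stride_plan(4; output_stride = 8)    # strides 1, 2, 1, 1 – later stages use dilation 2 and 4
```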
+ downsample_avg(down_kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation = prev_dilation, + norm_layer = kwarg_dict[:norm_layer]) : + downsample_conv(down_kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation = prev_dilation, + norm_layer = kwarg_dict[:norm_layer]) + end + block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, + :drop_block => db, kwargs...) + blocks = [] + for block_idx in 1:num_blocks + downsample = block_idx == 1 ? downsample : identity + stride = block_idx == 1 ? stride : 1 + # stochastic depth linear decay rule + block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + push!(blocks, + block_fn(inplanes, planes; stride, downsample, + first_dilation = prev_dilation, + drop_path = DropPath(block_dpr), block_kwargs...)) + prev_dilation = dilation + inplanes = planes * expansion + net_block_idx += 1 + end + push!(stages, Chain(blocks...)) + end + return Chain(stages...) end -@functor ResNet - -(m::ResNet)(x) = m.layers(x) - -backbone(m::ResNet) = m.layers[1] -classifier(m::ResNet) = m.layers[2] - -""" - ResNet(depth = 50; pretrain = false, nclasses = 1000) - -Create a ResNet model with a specified depth -([reference](https://arxiv.org/abs/1512.03385v1)) -following [these modification](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch) -referred as ResNet v1.5. - -See also [`Metalhead.resnet`](#). - -# Arguments - - - `depth`: depth of the ResNet model. Options include (18, 34, 50, 101, 152). - - `nclasses`: the number of output classes - -For `ResNet(18)` and `ResNet(34)`, the parameter-free shortcut style (type `:A`) -is used in the first block and the three other blocks use type `:B` connection -(following the implementation in PyTorch). The published version of -`ResNet(18)` and `ResNet(34)` used type `:A` shortcuts for all four blocks. The -example below shows how to create a 18 or 34-layer `ResNet` using only type `:A` -shortcuts: - -```julia -using Metalhead - -resnet18 = ResNet([1, 1], [2, 2, 2, 2], :A; block = Metalhead.basicblock) +function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, + expansion = 1, + cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, + replace_stem_pool = false, reduce_first = 1, + down_kernel_size = (1, 1), avg_down = false, act_layer = relu, + norm_layer = BatchNorm, + drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, + block_kwargs...) + @assert output_stride in (8, 16, 32) + @assert stem_type in [:default, :deep, :deep_tiered] + # Stem + inplanes = stem_type == :deep ? 
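The stochastic-depth linear decay rule in the block loop above assigns progressively larger drop-path rates to deeper blocks. Evaluating it for a ResNet-50-style layout (with `drop_path_rate = 0.1` chosen purely for illustration):

```julia
# Per-block drop-path rates under the linear decay rule used in `make_blocks` above.
block_repeats = [3, 4, 6, 3]
drop_path_rate = 0.1
total = sum(block_repeats)
block_dprs = [drop_path_rate * idx / (total - 1) for idx in 1:total]
# ≈ 0.0067, 0.0133, …, 0.1067 – deeper blocks are dropped more aggressively.
```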
stem_width * 2 : 64 + if stem_type == :deep + stem_channels = (stem_width, stem_width) + if stem_type == :deep_tiered + stem_channels = (3 * (stem_width ÷ 4), stem_width) + end + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + bias = false), + norm_layer(stem_channels[1]), + act_layer(), + Conv((3, 3), stem_channels[1] => stem_channels[1]; stride = 1, + pad = 1, bias = false), + norm_layer(stem_channels[2]), + act_layer(), + Conv((3, 3), stem_channels[2] => inplanes; stride = 1, pad = 1, + bias = false)) + else + conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) + end + bn1 = norm_layer(inplanes) + act1 = act_layer + # Stem pooling + if replace_stem_pool + stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, + bias = false), + norm_layer(inplanes), + act_layer) + else + stempool = MaxPool((3, 3); stride = 2, pad = 1) + end + stem = Chain(conv1, bn1, act1, stempool) -resnet34 = ResNet([1, 1], [3, 4, 6, 3], :A; block = Metalhead.basicblock) -``` + # Feature Blocks + channels = [64, 128, 256, 512] + stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, + output_stride, reduce_first, avg_down, + down_kernel_size, act_layer, norm_layer, + drop_block_rate, drop_path_rate, block_kwargs...) -The bottleneck of the orginal ResNet model has a stride of 2 on the first -convolutional layer when downsampling (instead of the second convolutional layers -as in ResNet v1.5). The architecture of the orignal ResNet model can be obtained -as shown below: + # Head (Pooling and Classifier) + num_features = 512 * expansion + classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, + Dense(num_features, num_classes)) -```julia -resnet50_v1 = ResNet([1, 1, 4], [3, 4, 6, 3], :B; block = Metalhead.bottleneck_v1) -``` -""" -function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000) - @assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))" - config, block = resnet_config[depth] - model = ResNet(config...; block = block, nclasses = nclasses) - pretrain && loadpretrain!(model, string("resnet", depth)) - return model + return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 8b2b73059..6c417c077 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -24,5 +24,6 @@ export MHAttention, ChannelLayerNorm, prenorm, skip_identity, skip_projection, conv_bn, depthwise_sep_conv_bn, - invertedresidual, squeeze_excite + invertedresidual, squeeze_excite, + DropBlock end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 93c120651..b3f9a8719 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -7,45 +7,48 @@ Implements DropBlock, a regularization method for convolutional networks. 
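For reference, the `:default` stem path assembled in `resnet` above reduces the spatial resolution by a factor of 4 before the residual stages. A minimal Flux sketch with illustrative sizes (the activation is wrapped so it broadcasts inside a `Chain`):

```julia
using Flux

inchannels, inplanes = 3, 64
act = x -> relu.(x)
stem = Chain(Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false),
             BatchNorm(inplanes),
             act,
             MaxPool((3, 3); stride = 2, pad = 1))

size(stem(rand(Float32, 224, 224, inchannels, 1)))  # (56, 56, 64, 1)
```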
struct DropBlock{F} drop_prob::F block_size::Integer + gamma_scale::F end @functor DropBlock -(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size) +(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size, m.gamma_scale) -DropBlock(drop_prob = 0.1, block_size = 7) = DropBlock(drop_prob, block_size) +function DropBlock(drop_prob = 0.1, block_size = 7, gamma_scale = 1.0) + return DropBlock(drop_prob, block_size, gamma_scale) +end -function _dropblock_checks(x, drop_prob, T) +function _dropblock_checks(x, drop_prob, gamma_scale, T) if !(T <: AbstractArray) throw(ArgumentError("x must be an `AbstractArray`")) end if ndims(x) != 4 throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) end - @assert drop_prob < 0 || drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" + @assert drop_prob < 0||drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" + @assert gamma_scale < 0||gamma_scale > 1 "gamma_scale must be between 0 and 1, got $gamma_scale" end -ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, T) +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, gamma_scale, T) -function dropblock(x::T, drop_prob, block_size::Integer) where {T} - _dropblock_checks(x, drop_prob, T) +function dropblock(x::T, drop_prob, block_size::Integer, gamma_scale) where {T} + _dropblock_checks(x, drop_prob, gamma_scale, T) if drop_prob == 0 return x end - return _dropblock(x, drop_prob, block_size) + return _dropblock(x, drop_prob, block_size, gamma_scale) end -function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size) where {T} - gamma = drop_prob / (block_size ^ 2) - mask = rand_like(x, Float32, (size(x, 1), size(x, 2), size(x, 3))) - mask .<= gamma - block_mask = maxpool(reshape(mask, (size(mask)[1:3]..., 1)), (block_size, block_size); - pad = block_size ÷ 2, stride = (1, 1)) - if block_size % 2 == 0 - block_mask = block_mask[1:(end - 1), 1:(end - 1), :, :] - end - block_mask = 1 .- dropdims(block_mask; dims = 4) - out = (x .* reshape(block_mask, (size(block_mask)[1:3]..., 1))) * length(block_mask) / - sum(block_mask) - return out +function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size, gamma_scale) where {T} + H, W, _, _ = size(x) + total_size = H * W + clipped_block_size = min(block_size, min(H, W)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size^2 / + ((W - block_size + 1) * (H - block_size + 1)) + block_mask = rand_like(x) .< gamma + block_mask = maxpool(convert(T, block_mask), (clipped_block_size, clipped_block_size); + stride = 1, padding = clipped_block_size ÷ 2) + block_mask = 1 .- block_mask + normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) + return x * block_mask * normalize_scale end """ diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 4f69dab03..2d5e6399a 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -19,9 +19,9 @@ end @functor ChannelLayerNorm -(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) - function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) diag = Flux.Scale(1, 1, sz, λ) return ChannelLayerNorm(diag, ϵ) end + +(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) From 2a5d0cc9d94c914e402b0b0f9e68e3d9eceb5825 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 22 Jun 2022 07:33:48 +0530 Subject: [PATCH 27/64] Cleanup --- src/convnets/inception.jl | 59 +++++++-------- 
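A worked example of the seeding rate `gamma` from the revised `_dropblock` above, with all numbers chosen purely for illustration (a 14×14 feature map, `block_size = 5`, `drop_prob = 0.1`, `gamma_scale = 0.25`):

```julia
H, W, block_size, drop_prob, gamma_scale = 14, 14, 5, 0.1, 0.25
total_size = H * W
clipped_block_size = min(block_size, min(H, W))
gamma = gamma_scale * drop_prob * total_size / clipped_block_size^2 /
        ((W - block_size + 1) * (H - block_size + 1))
# ≈ 0.00196, i.e. about 0.38 expected seeds per 14×14 channel, each grown into a 5×5 block.
```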
src/convnets/resnet.jl | 146 ++++++++++++++++++-------------------- src/convnets/vgg.jl | 36 ++++------ src/layers/attention.jl | 16 ++--- src/layers/drop.jl | 5 +- src/layers/mlp-linear.jl | 20 +++--- src/other/mlpmixer.jl | 29 ++++---- src/vit-based/vit.jl | 29 ++++---- 8 files changed, 169 insertions(+), 171 deletions(-) diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index c3fd39f5e..e4106e957 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -279,7 +279,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) + inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -287,10 +287,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `dropout`: rate of dropout in classifier head. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) +function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -313,12 +313,13 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) inceptionv4_c(), inceptionv4_c(), inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses)) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + Dense(1536, nclasses)) return Chain(body, head) end """ - Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) + Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) Creates an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -326,8 +327,8 @@ Creates an Inceptionv4 model. # Arguments - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning @@ -338,7 +339,7 @@ struct Inceptionv4 layers::Any end -function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) +function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionv4(; inchannels, dropout, nclasses) pretrain && loadpretrain!(layers, "Inceptionv4") return Inceptionv4(layers) @@ -419,18 +420,18 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) + inceptionresnetv2(; inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) # Arguments - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. 
""" -function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) +function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -446,12 +447,13 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), conv_bn((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses)) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + Dense(1536, nclasses)) return Chain(body, head) end """ - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) + InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -459,8 +461,8 @@ Creates an InceptionResNetv2 model. # Arguments - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning @@ -471,9 +473,9 @@ struct InceptionResNetv2 layers::Any end -function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, +function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = inceptionresnetv2(; inchannels, dropout, nclasses) + layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "InceptionResNetv2") return InceptionResNetv2(layers) end @@ -533,18 +535,18 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end """ - xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) + xception(; inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) +function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)..., conv_bn((3, 3), 32, 64; bias = false)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), @@ -554,7 +556,8 @@ function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) xception_block(728, 1024, 2; stride = 2, grow_at_start = false), depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses)) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + Dense(2048, nclasses)) return Chain(body, head) end @@ -563,7 +566,7 @@ struct Xception end """ - Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) + Xception(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -571,16 +574,16 @@ Creates an Xception model. 
# Arguments - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. + - `inchannels`: number of input channels. + - `drop_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning `Xception` does not currently support pretrained weights. """ -function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - layers = xception(; inchannels, dropout, nclasses) +function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "xception") return Xception(layers) end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl index 768697131..875421360 100644 --- a/src/convnets/resnet.jl +++ b/src/convnets/resnet.jl @@ -1,9 +1,42 @@ +function drop_blocks(drop_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_prob, 5, 0.25), + DropBlock(drop_prob, 3, 1.00), + ] +end + +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) +end + +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 + pool = avg_pool_fn((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), in_channels => out_channels; bias = false), + norm_layer(out_channels)) +end + function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, - reduce_first = 1, dilation = 1, first_dilation = nothing, - act_layer = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity) - expansion = 1 + expansion = expansion_factor(basicblock) @assert cardinality==1 "BasicBlock only supports cardinality of 1" @assert base_width==64 "BasicBlock does not support changing base width" first_planes = planes ÷ reduce_first @@ -17,16 +50,16 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina dilation = dilation, bias = false), norm_layer(outplanes)) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_path)), - act_layer) + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + activation) end +expansion_factor(::typeof(basicblock)) = 1 function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, - reduce_first = 1, dilation = 1, first_dilation = nothing, - act_layer = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity) - expansion = 4 + expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first outplanes = planes * expansion @@ -39,55 +72,25 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina drop_block = drop_block === identity ? identity : drop_block() conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_block, - act_layer, conv_bn3, drop_path)), - act_layer) -end - -function drop_blocks(drop_prob = 0.0) - return [identity, identity, - drop_prob == 0.0 ? DropBlock(drop_prob, 5, 0.25) : identity, - drop_prob == 0.0 ? DropBlock(drop_prob, 3, 1.00) : identity] + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_block, + activation, conv_bn3, drop_path)), + activation) end +expansion_factor(::typeof(bottleneck)) = 4 -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(out_channels)) -end - -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 - pool = avg_pool_fn((2, 2); stride = avg_stride, pad) - end - - return Chain(pool, - Conv((1, 1), in_channels => out_channels; stride = 1, pad = 0, - bias = false), - norm_layer(out_channels)) -end - -function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, - reduce_first = 1, output_stride = 32, - down_kernel_size = 1, avg_down = false, drop_block_rate = 0.0, - drop_path_rate = 0.0, kwargs...) +function make_blocks(block_fn, channels, block_repeats, inplanes; + reduce_first = 1, output_stride = 32, down_kernel_size = 1, + avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, + kwargs...) + expansion = expansion_factor(block_fn) kwarg_dict = Dict(kwargs...) stages = [] net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 - for (stage_idx, (planes, num_blocks, db)) in enumerate(zip(channels, block_repeats, - drop_blocks(drop_block_rate))) + for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, + block_repeats, + drop_blocks(drop_block_rate))) stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride dilation *= stride @@ -95,6 +98,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, else net_stride *= stride end + # first block needs to be handled differently for downsampling downsample = identity if stride != 1 || inplanes != planes * expansion downsample = avg_down ? @@ -106,7 +110,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, norm_layer = kwarg_dict[:norm_layer]) end block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => db, kwargs...) + :drop_block => drop_block, kwargs...) blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? downsample : identity @@ -127,15 +131,13 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, end function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, - expansion = 1, cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, - down_kernel_size = (1, 1), avg_down = false, act_layer = relu, - norm_layer = BatchNorm, + replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), + avg_down = false, activation = relu, norm_layer = BatchNorm, drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, block_kwargs...) - @assert output_stride in (8, 16, 32) - @assert stem_type in [:default, :deep, :deep_tiered] + @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" + @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Stem inplanes = stem_type == :deep ? 
stem_width * 2 : 64 if stem_type == :deep @@ -145,38 +147,32 @@ function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride end conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, bias = false), - norm_layer(stem_channels[1]), - act_layer(), - Conv((3, 3), stem_channels[1] => stem_channels[1]; stride = 1, - pad = 1, bias = false), - norm_layer(stem_channels[2]), - act_layer(), - Conv((3, 3), stem_channels[2] => inplanes; stride = 1, pad = 1, - bias = false)) + norm_layer(stem_channels[1], activation), + Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, + bias = false), + norm_layer(stem_channels[2], activation), + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) end - bn1 = norm_layer(inplanes) - act1 = act_layer + bn1 = norm_layer(inplanes, activation) # Stem pooling if replace_stem_pool stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, bias = false), - norm_layer(inplanes), - act_layer) + norm_layer(inplanes, activation)) else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - stem = Chain(conv1, bn1, act1, stempool) - + stem = Chain(conv1, bn1, stempool) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, output_stride, reduce_first, avg_down, - down_kernel_size, act_layer, norm_layer, + down_kernel_size, activation, norm_layer, drop_block_rate, drop_path_rate, block_kwargs...) - # Head (Pooling and Classifier) + expansion = expansion_factor(block) num_features = 512 * expansion classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, Dense(num_features, num_classes)) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 56975a124..15560de7c 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -52,7 +52,7 @@ function vgg_convolutional_layers(config, batchnorm, inchannels) end """ - vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -63,19 +63,19 @@ Create VGG classifier (fully connected) layers the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `dropout`: the dropout level between each fully connected layer + - `drop_rate`: the dropout level between each fully connected layer """ -function vgg_classifier_layers(imsize, nclasses, fcsize, dropout) +function vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) return Chain(MLUtils.flatten, Dense(Int(prod(imsize)), fcsize, relu), - Dropout(dropout), + Dropout(drop_rate), Dense(fcsize, fcsize, relu), - Dropout(dropout), + Dropout(drop_rate), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). 
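As a quick illustration of the renamed keyword in the VGG helpers, a sketch under the assumption that the classifier sees the usual 7×7×512 feature map from a 224×224 input; `4096` and `0.5` match the defaults used further below in this file:

```julia
using Metalhead

# builds the fully connected head: flatten -> Dense(25088, 4096) -> Dropout(0.5)
# -> Dense(4096, 4096) -> Dropout(0.5) -> Dense(4096, 1000)
head = Metalhead.vgg_classifier_layers((7, 7, 512), 1000, 4096, 0.5)
```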
@@ -90,12 +90,12 @@ Create a VGG model - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `drop_rate`: dropout level between fully connected layers """ -function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) +function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + class = vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) return Chain(Chain(conv), class) end @@ -114,7 +114,7 @@ struct VGG end """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -126,17 +126,11 @@ Construct a VGG model with the specified input image size. Typically, the image - `nclasses`::Integer : number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `drop_rate`: dropout level between fully connected layers """ -function VGG(imsize::Dims{2}; - config, inchannels, batchnorm = false, nclasses, fcsize, dropout) - layers = vgg(imsize; config = config, - inchannels = inchannels, - batchnorm = batchnorm, - nclasses = nclasses, - fcsize = fcsize, - dropout = dropout) - +function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, + drop_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, drop_rate) return VGG(layers) end @@ -165,7 +159,7 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses batchnorm = batchnorm, nclasses = nclasses, fcsize = 4096, - dropout = 0.5) + drop_rate = 0.5) if pretrain && !batchnorm loadpretrain!(model, string("vgg", depth)) elseif pretrain diff --git a/src/layers/attention.jl b/src/layers/attention.jl index a1244a033..b6e7b7678 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -7,18 +7,18 @@ Multi-head self-attention layer. - `nheads`: Number of heads - `qkv_layer`: layer to be used for getting the query, key and value - - `attn_drop`: dropout rate after the self-attention layer + - `attn_drop_rate`: dropout rate after the self-attention layer - `projection`: projection layer to be used after self-attention """ struct MHAttention{P, Q, R} nheads::Int qkv_layer::P - attn_drop::Q + attn_drop_rate::Q projection::R end """ - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop = 0., proj_drop = 0.) + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.) Multi-head self-attention layer. @@ -27,15 +27,15 @@ Multi-head self-attention layer. 
- `planes`: number of input channels - `nheads`: number of heads - `qkv_bias`: whether to use bias in the layer to get the query, key and value - - `attn_drop`: dropout rate after the self-attention layer - - `proj_drop`: dropout rate after the projection layer + - `attn_drop_rate`: dropout rate after the self-attention layer + - `proj_drop_rate`: dropout rate after the projection layer """ function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_drop = 0.0, proj_drop = 0.0) + attn_drop_rate = 0.0, proj_drop_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_drop) - proj = Chain(Dense(planes, planes), Dropout(proj_drop)) + attn_drop = Dropout(attn_drop_rate) + proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate)) return MHAttention(nheads, qkv_layer, attn_drop, proj) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b3f9a8719..8e6202085 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -54,11 +54,12 @@ end """ DropPath(p) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0 and +`identity` otherwise. ([reference](https://arxiv.org/abs/1603.09382)) # Arguments - `p`: rate of Stochastic Depth. """ -DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity +DropPath(p) = p > 0 ? Dropout(p; dims = 4) : identity diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp-linear.jl index e282e2632..550c2ad22 100644 --- a/src/layers/mlp-linear.jl +++ b/src/layers/mlp-linear.jl @@ -15,7 +15,7 @@ end """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0., activation = gelu) + drop_rate =0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -24,18 +24,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models. - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `drop_rate`: Dropout rate. - `activation`: Activation function to use. """ function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout), - Dense(hidden_planes, outplanes), Dropout(dropout)) + drop_rate = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(drop_rate), + Dense(hidden_planes, outplanes), Dropout(drop_rate)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0., activation = gelu) + outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -46,16 +46,16 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `drop_rate`: Dropout rate. - `activation`: Activation function to use. 
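For reference, a small sketch of how the renamed dropout keywords thread through these layers; it assumes the internal layers module is reachable as `Metalhead.Layers`, and the dimensions are illustrative:

```julia
using Metalhead.Layers: MHAttention, mlp_block

# 768 planes split across 12 heads; both dropout keywords now end in `_rate`
attn = MHAttention(768, 12; qkv_bias = true, attn_drop_rate = 0.1, proj_drop_rate = 0.1)

# feedforward block expanding 768 -> 3072 -> 768, with dropout after each Dense
ffn = mlp_block(768, 3072; drop_rate = 0.1)
```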
""" function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0.0, activation = gelu) + outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(dropout), + Dropout(drop_rate), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(dropout)) + Dropout(drop_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 942abc823..ed4c47af3 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0., drop_path_rate = 0., activation = gelu) + drop_rate =0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -12,20 +12,22 @@ Creates a feedforward block for the MLPMixer architecture. - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP and/or the channel mixing MLP as a ratio to the number of planes in the block. - `mlp_layer`: the MLP layer to use in the block - - `dropout`: the dropout rate to use in the MLP blocks + - `drop_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0.0, drop_path_rate = 0.0, activation = gelu) + drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] return Chain(SkipConnection(Chain(LayerNorm(planes), swapdims((2, 1, 3)), - mlp_layer(npatches, tokenplanes; activation, dropout), + mlp_layer(npatches, tokenplanes; activation, + drop_rate), swapdims((2, 1, 3)), DropPath(drop_path_rate)), +), SkipConnection(Chain(LayerNorm(planes), - mlp_layer(planes, channelplanes; activation, dropout), + mlp_layer(planes, channelplanes; activation, + drop_rate), DropPath(drop_path_rate)), +)) end @@ -113,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] """ - resmixerblock(planes, npatches; dropout = 0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; drop_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, activation = gelu, λ = 1e-4) Creates a block for the ResMixer architecture. @@ -126,13 +128,13 @@ Creates a block for the ResMixer architecture. 
- `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `mlp_layer`: the MLP block to use - - `dropout`: the dropout rate to use in the MLP blocks + - `drop_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks - `λ`: initialisation constant for the LayerScale """ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - dropout = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) + drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) return Chain(SkipConnection(Chain(Flux.Scale(planes), swapdims((2, 1, 3)), Dense(npatches, npatches), @@ -140,7 +142,7 @@ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, LayerScale(planes, λ), DropPath(drop_path_rate)), +), SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); dropout, + mlp_layer(planes, Int(mlp_ratio * planes); drop_rate, activation), LayerScale(planes, λ), DropPath(drop_path_rate)), +)) @@ -230,7 +232,7 @@ end """ spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, dropout = 0.0, drop_path_rate = 0., + norm_layer = LayerNorm, drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) Creates a feedforward block based on the gMLP model architecture described in the paper. @@ -243,18 +245,19 @@ Creates a feedforward block based on the gMLP model architecture described in th - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `norm_layer`: the normalisation layer to use - - `dropout`: the dropout rate to use in the MLP blocks + - `drop_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, dropout = 0.0, + mlp_layer = gated_mlp_block, drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) channelplanes = Int(mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), - mlp_layer(sgu, planes, channelplanes; activation, dropout), + mlp_layer(sgu, planes, channelplanes; activation, + drop_rate), DropPath(drop_path_rate)), +) end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 012bfef9d..686ddc4d5 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate =0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). @@ -10,23 +10,24 @@ Transformer as used in the base ViT architecture. 
- `depth`: number of attention blocks - `nheads`: number of attention heads - `mlp_ratio`: ratio of MLP layers to the number of input channels - - `dropout`: dropout rate + - `drop_rate`: dropout rate """ -function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.0) +function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate = 0.0) layers = [Chain(SkipConnection(prenorm(planes, - MHAttention(planes, nheads; attn_drop = dropout, - proj_drop = dropout)), +), + MHAttention(planes, nheads; + attn_drop_rate = drop_rate, + proj_drop_rate = drop_rate)), +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); - dropout)), +)) + drop_rate)), +)) for _ in 1:depth] return Chain(layers) end """ vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, - emb_dropout = 0.1, pool = :class, nclasses = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + emb_drop_rate = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -40,22 +41,23 @@ Creates a Vision Transformer (ViT) model. - `depth`: number of blocks in the transformer - `nheads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer - - `dropout`: dropout rate + - `drop_rate`: dropout rate - `emb_dropout`: dropout rate for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, - emb_dropout = 0.1, pool = :class, nclasses = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + emb_drop_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), ViPosEmbedding(embedplanes, npatches + 1), - Dropout(emb_dropout), - transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout), + Dropout(emb_drop_rate), + transformer_encoder(embedplanes, depth, nheads; mlp_ratio, + drop_rate), (pool == :class) ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end @@ -98,7 +100,6 @@ function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))" kwargs = vit_configs[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) - return ViT(layers) end From 07c1e9596562321d7eb759404533e2c8ca665fdb Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 23 Jun 2022 11:49:46 +0530 Subject: [PATCH 28/64] Get some stuff to work 1. Some docs 2. 
Basic tests for ResNet and ResNeXt now pass --- src/Metalhead.jl | 7 +- src/convnets/resne(x)t.jl | 377 ++++++++++++++++++++++++++++++++++++++ src/convnets/resnet.jl | 181 ------------------ src/convnets/resnext.jl | 126 ------------- src/layers/drop.jl | 43 +++-- test/convnets.jl | 19 +- 6 files changed, 407 insertions(+), 346 deletions(-) create mode 100644 src/convnets/resne(x)t.jl delete mode 100644 src/convnets/resnet.jl delete mode 100644 src/convnets/resnext.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 9a60ad351..5463c64de 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -22,8 +22,7 @@ include("convnets/alexnet.jl") include("convnets/vgg.jl") include("convnets/inception.jl") include("convnets/googlenet.jl") -include("convnets/resnet.jl") -include("convnets/resnext.jl") +include("convnets/resne(x)t.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") @@ -40,7 +39,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, -# ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -49,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, # :ResNet, +for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, :ResNet, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl new file mode 100644 index 000000000..abf7193b1 --- /dev/null +++ b/src/convnets/resne(x)t.jl @@ -0,0 +1,377 @@ +# returns `DropBlock`s for each block of the ResNet +function _drop_blocks(drop_block_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_block_prob, 5, 0.25), + DropBlock(drop_block_prob, 3, 1.00), + ] +end + +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) +end + +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 + pool = MeanPool((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), in_channels => out_channels; bias = false), + norm_layer(out_channels)) +end + +function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = expansion_factor(basicblock) + @assert cardinality==1 "`basicblock` only supports cardinality of 1" + @assert base_width==64 "`basicblock` does not support changing base width" + first_planes = planes ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? first_dilation : dilation + conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, + dilation = first_dilation, bias = false), + norm_layer(first_planes)) + drop_block = drop_block + conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = dilation, + dilation = dilation, bias = false), + norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + activation) +end +expansion_factor(::typeof(basicblock)) = 1 + +function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity) + expansion = expansion_factor(bottleneck) + width = floor(Int, planes * (base_width / 64)) * cardinality + first_planes = width ÷ reduce_first + outplanes = planes * expansion + first_dilation = !isnothing(first_dilation) ? first_dilation : dilation + conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), + norm_layer(first_planes, activation)) + conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, + dilation = first_dilation, groups = cardinality, bias = false), + norm_layer(width)) + conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) + return Chain(Parallel(+, downsample, + Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, + drop_path)), + activation) +end +expansion_factor(::typeof(bottleneck)) = 4 + +# Makes the main stages of the ResNet model. This is an internal function and should not be +# used by end-users. `block_fn` is a function that returns a single block of the ResNet. +# See `basicblock` and `bottleneck` for examples. A block must define a function +# `expansion(::typeof(block))` that returns the expansion factor of the block. +function _make_blocks(block_fn, channels, block_repeats, inplanes; + reduce_first = 1, output_stride = 32, down_kernel_size = (1, 1), + avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, + kwargs...) + expansion = expansion_factor(block_fn) + kwarg_dict = Dict(kwargs...) + stages = [] + net_block_idx = 1 + net_stride = 4 + dilation = prev_dilation = 1 + for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, + block_repeats, + _drop_blocks(drop_block_rate))) + # Stride calculations for each stage + stride = stage_idx == 1 ? 1 : 2 + if net_stride >= output_stride + dilation *= stride + stride = 1 + else + net_stride *= stride + end + # use average pooling for projection skip connection between stages/downsample. + downsample = identity + if stride != 1 || inplanes != planes * expansion + downsample_fn = avg_down ? 
downsample_avg : downsample_conv + downsample = downsample_fn(down_kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation = dilation, + norm_layer = kwarg_dict[:norm_layer]) + end + # arguments to be passed into the block function + block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, + :drop_block => drop_block, kwargs...) + blocks = [] + for block_idx in 1:num_blocks + downsample = block_idx == 1 ? downsample : identity + stride = block_idx == 1 ? stride : 1 + # stochastic depth linear decay rule + block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + push!(blocks, + block_fn(inplanes, planes; stride, downsample, + first_dilation = prev_dilation, + drop_path = DropPath(block_dpr), block_kwargs...)) + prev_dilation = dilation + inplanes = planes * expansion + net_block_idx += 1 + end + push!(stages, Chain(blocks...)) + end + return Chain(stages...) +end + +""" + resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, + cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, + replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), + avg_down = false, activation = relu, norm_layer = BatchNorm, + drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, + block_kwargs...) + +Creates the layers of a ResNe(X)t model. If you are an end-user, you should probably use +[ResNet](@ref) instead and pass in the parameters you want to modify as optional parameters +there. + +# Arguments: + + - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for + example. + + - `layers`: A list of integers representing the number of blocks in each stage. + - `nclasses`: The number of output classes. The default value is 1000. + - `inchannels`: The number of input channels to the model. The default value is 3. + - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. + - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. + This is used for [ResNeXt](@ref)-like models. The default value is 1. + - `base_width`: The base width of each bottleneck block. It is the factor determining + the number of bottleneck channels: `planes * base_width / 64 * cardinality`. + The default value is 64. + - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. + - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: + + + `:default` - a single 7x7 convolution layer with a width of `stem_width` + + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` + + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. + The default value is `:default`. + - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a + convolution layer. The default value is false. + - `reduce_first`: Reduction factor for first convolution output width of residual blocks, + Default is 1 for all architectures except SE-Nets, where it is 2. + - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the + skip connection. The default value is (1, 1) for all architectures except + SE-Nets, where it is (3, 3). + - `avg_down`: Use average pooling for projection skip connection between stages/downsample. + - `activation`: The activation function to use. The default value is `relu`. + - `norm_layer`: The normalization layer to use. 
The default value is `BatchNorm`. + - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. + - `drop_path_rate`: The rate to use for `DropPath`. The default value is 0.0. + - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. + - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. + +If you are an end-user trying to tweak the ResNet model, note that there is no guarantee that +all combinations of parameters will work. In particular, tweaking `block_kwargs` is not +advised unless you know what you are doing. +""" +function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, + cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, + replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), + avg_down = false, activation = relu, norm_layer = BatchNorm, + drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, + block_kwargs...) + @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" + @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + # Stem + inplanes = stem_type == :deep ? stem_width * 2 : 64 + if stem_type == :deep + stem_channels = (stem_width, stem_width) + if stem_type == :deep_tiered + stem_channels = (3 * (stem_width ÷ 4), stem_width) + end + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + bias = false), + norm_layer(stem_channels[1], activation), + Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, + bias = false), + norm_layer(stem_channels[2], activation), + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) + else + conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) + end + bn1 = norm_layer(inplanes, activation) + # Stem pooling + if replace_stem_pool + stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, + bias = false), + norm_layer(inplanes, activation)) + else + stempool = MaxPool((3, 3); stride = 2, pad = 1) + end + stem = Chain(conv1, bn1, stempool) + # Feature Blocks + channels = [64, 128, 256, 512] + stage_blocks = _make_blocks(block, channels, layers, inplanes; cardinality, base_width, + output_stride, reduce_first, avg_down, + down_kernel_size, activation, norm_layer, + drop_block_rate, drop_path_rate, block_kwargs...) + # Head (Pooling and Classifier) + expansion = expansion_factor(block) + num_features = 512 * expansion + classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, + Dense(num_features, nclasses)) + return Chain(Chain(stem, stage_blocks), classifier) +end + +const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), + 34 => (basicblock, [3, 4, 6, 3]), + 50 => (bottleneck, [3, 4, 6, 3]), + 101 => (bottleneck, [3, 4, 23, 3]), + 152 => (bottleneck, [3, 8, 36, 3])) + +""" + ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) + +Creates a ResNet model. +((reference)[https://arxiv.org/abs/1512.03385]) + +# Arguments: + + - `depth`: The depth of the `ResNet` model. Must be one of `[18, 34, 50, 101, 152]`. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. + - `nclasses`: The number of output classes. The default value is 1000. + +Apart from these, the model can also take any of the following optional arguments: + + - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for + example. 
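A sketch of how the pieces above compose, assuming the exported constructor wraps `resnet` as defined in this file (keyword values are illustrative, not recommendations):

```julia
using Metalhead

# a ResNet-50-style backbone built directly from the building blocks above
m = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3]; nclasses = 10,
                     drop_path_rate = 0.1)

# the user-facing wrapper picks the block and layer configuration from `resnet_config`
m50 = ResNet(50; nclasses = 10)
```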
+ + - `layers`: A list of integers representing the number of blocks in each stage. + - `inchannels`: The number of input channels to the model. The default value is 3. + - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. + - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. + This is used for [ResNeXt](@ref)-like models. The default value is 1. + - `base_width`: The base width of each bottleneck block. It is the factor determining + the number of bottleneck channels: `planes * base_width / 64 * cardinality`. + The default value is 64. + - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. + - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: + + + `:default` - a single 7x7 convolution layer with a width of `stem_width` + + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` + + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. + The default value is `:default`. + - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a + convolution layer. The default value is false. + - `reduce_first`: Reduction factor for first convolution output width of residual blocks, + Default is 1 for all architectures except SE-Nets, where it is 2. + - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the + skip connection. The default value is (1, 1) for all architectures except + SE-Nets, where it is (3, 3). + - `avg_down`: Use average pooling for projection skip connection between stages/downsample. + - `activation`: The activation function to use. The default value is `relu`. + - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. + - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. + - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. + - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. + +See also [`resnet`](@ref) for more details. + +!!! warning + + Pretrained models are not supported for all parameter combinations of `ResNet`. +""" +struct ResNet + layers::Any +end +@functor ResNet + +function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) + @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + model = resnet(resnet_config[depth]...; nclasses, kwargs...) + pretrain && loadpretrain!(model, string("resnet", depth)) + return model +end + +""" + ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, nclasses = 1000, + kwargs...) + +Creates a ResNeXt model. +((reference)[https://arxiv.org/abs/1611.05431]) + +# Arguments: + + - `depth`: The depth of the `ResNeXt` model. Must be one of `[50, 101, 152]`. + - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. + of the `ResNeXt` mode. The default value is 4. + - `base_width`: The base width of each bottleneck block. It is the factor determining + the number of bottleneck channels. The default value is 32. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. + - `nclasses`: The number of output classes. The default value is 1000. + +Apart from these, the model can also take any of the following optional arguments: + + - `block`: The block to use in the ResNet model. 
See [basicblock](@ref) and [bottleneck](@ref) for + example. + + - `layers`: A list of integers representing the number of blocks in each stage. + - `inchannels`: The number of input channels to the model. The default value is 3. + - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. + - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. + This is used for [ResNeXt](@ref)-like models. The default value is 1. + - `base_width`: The base width of each bottleneck block. It is the factor determining + the number of bottleneck channels: `planes * base_width / 64 * cardinality`. + The default value is 64. + - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. + - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: + + + `:default` - a single 7x7 convolution layer with a width of `stem_width` + + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` + + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. + The default value is `:default`. + - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a + convolution layer. The default value is false. + - `reduce_first`: Reduction factor for first convolution output width of residual blocks, + Default is 1 for all architectures except SE-Nets, where it is 2. + - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the + skip connection. The default value is (1, 1) for all architectures except + SE-Nets, where it is (3, 3). + - `avg_down`: Use average pooling for projection skip connection between stages/downsample. + - `activation`: The activation function to use. The default value is `relu`. + - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. + - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. + - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. + - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. + +See also [`resnet`](@ref) for more details. + +!!! warning + + Pretrained models are not currently supported for `ResNeXt`. +""" +struct ResNeXt + layers::Any +end +@functor ResNeXt + +function ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, + nclasses = 1000, + kwargs...) + @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" + model = resnet(resnet_config[depth]...; cardinality, base_width, nclasses, kwargs...) + pretrain && + loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + return model +end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl deleted file mode 100644 index 875421360..000000000 --- a/src/convnets/resnet.jl +++ /dev/null @@ -1,181 +0,0 @@ -function drop_blocks(drop_prob = 0.0) - return [ - identity, - identity, - DropBlock(drop_prob, 5, 0.25), - DropBlock(drop_prob, 3, 1.00), - ] -end - -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? 
first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(out_channels)) -end - -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 - pool = avg_pool_fn((2, 2); stride = avg_stride, pad) - end - return Chain(pool, - Conv((1, 1), in_channels => out_channels; bias = false), - norm_layer(out_channels)) -end - -function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) - expansion = expansion_factor(basicblock) - @assert cardinality==1 "BasicBlock only supports cardinality of 1" - @assert base_width==64 "BasicBlock does not support changing base width" - first_planes = planes ÷ reduce_first - outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation - conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, - dilation = first_dilation, bias = false), - norm_layer(first_planes)) - drop_block = drop_block === identity ? identity : drop_block - conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; stride, pad = dilation, - dilation = dilation, bias = false), - norm_layer(outplanes)) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), - activation) -end -expansion_factor(::typeof(basicblock)) = 1 - -function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) - expansion = expansion_factor(bottleneck) - width = floor(Int, planes * (base_width / 64)) * cardinality - first_planes = width ÷ reduce_first - outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation - conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), - norm_layer(first_planes)) - conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, - dilation = first_dilation, groups = cardinality, bias = false), - norm_layer(width)) - drop_block = drop_block === identity ? identity : drop_block() - conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, drop_block, - activation, conv_bn3, drop_path)), - activation) -end -expansion_factor(::typeof(bottleneck)) = 4 - -function make_blocks(block_fn, channels, block_repeats, inplanes; - reduce_first = 1, output_stride = 32, down_kernel_size = 1, - avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - kwargs...) - expansion = expansion_factor(block_fn) - kwarg_dict = Dict(kwargs...) - stages = [] - net_block_idx = 1 - net_stride = 4 - dilation = prev_dilation = 1 - for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, - block_repeats, - drop_blocks(drop_block_rate))) - stride = stage_idx == 1 ? 
1 : 2 - if net_stride >= output_stride - dilation *= stride - stride = 1 - else - net_stride *= stride - end - # first block needs to be handled differently for downsampling - downsample = identity - if stride != 1 || inplanes != planes * expansion - downsample = avg_down ? - downsample_avg(down_kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation = prev_dilation, - norm_layer = kwarg_dict[:norm_layer]) : - downsample_conv(down_kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation = prev_dilation, - norm_layer = kwarg_dict[:norm_layer]) - end - block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => drop_block, kwargs...) - blocks = [] - for block_idx in 1:num_blocks - downsample = block_idx == 1 ? downsample : identity - stride = block_idx == 1 ? stride : 1 - # stochastic depth linear decay rule - block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) - push!(blocks, - block_fn(inplanes, planes; stride, downsample, - first_dilation = prev_dilation, - drop_path = DropPath(block_dpr), block_kwargs...)) - prev_dilation = dilation - inplanes = planes * expansion - net_block_idx += 1 - end - push!(stages, Chain(blocks...)) - end - return Chain(stages...) -end - -function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, - cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), - avg_down = false, activation = relu, norm_layer = BatchNorm, - drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, - block_kwargs...) - @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" - @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" - # Stem - inplanes = stem_type == :deep ? stem_width * 2 : 64 - if stem_type == :deep - stem_channels = (stem_width, stem_width) - if stem_type == :deep_tiered - stem_channels = (3 * (stem_width ÷ 4), stem_width) - end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, - bias = false), - norm_layer(stem_channels[1], activation), - Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, - bias = false), - norm_layer(stem_channels[2], activation), - Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) - else - conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) - end - bn1 = norm_layer(inplanes, activation) - # Stem pooling - if replace_stem_pool - stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, - bias = false), - norm_layer(inplanes, activation)) - else - stempool = MaxPool((3, 3); stride = 2, pad = 1) - end - stem = Chain(conv1, bn1, stempool) - # Feature Blocks - channels = [64, 128, 256, 512] - stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, - output_stride, reduce_first, avg_down, - down_kernel_size, activation, norm_layer, - drop_block_rate, drop_path_rate, block_kwargs...) 
- # Head (Pooling and Classifier) - expansion = expansion_factor(block) - num_features = 512 * expansion - classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, - Dense(num_features, num_classes)) - - return Chain(Chain(stem, stage_blocks), classifier) -end diff --git a/src/convnets/resnext.jl b/src/convnets/resnext.jl deleted file mode 100644 index fc00bb180..000000000 --- a/src/convnets/resnext.jl +++ /dev/null @@ -1,126 +0,0 @@ -""" - resnextblock(inplanes, outplanes, cardinality, width, downsample = false) - -Create a basic residual block as defined in the paper for ResNeXt -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `downsample`: set to `true` to downsample the input -""" -function resnextblock(inplanes, outplanes, cardinality, width, downsample = false) - stride = downsample ? 2 : 1 - hidden_channels = cardinality * width - return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false)..., - conv_bn((3, 3), hidden_channels, hidden_channels; - stride = stride, pad = 1, bias = false, groups = cardinality)..., - conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false)...) -end - -""" - resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @. relu(x) + relu(y); - block_config, nclasses = 1000) - -Create a ResNeXt model -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments - - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `widen_factor`: the factor by which the width of the bottleneck is increased after each stage - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnext(cardinality, width, widen_factor = 2, - connection = (x, y) -> @. relu(x) + relu(y); - block_config, nclasses = 1000) - inplanes = 64 - baseplanes = 128 - layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3))) - push!(layers, MaxPool((3, 3); stride = (2, 2), pad = (1, 1))) - for (i, nrepeats) in enumerate(block_config) - # output planes within a block - outplanes = baseplanes * widen_factor - # push first skip connection on using first residual - # downsample the residual path if this is the first repetition of a block - push!(layers, - Parallel(connection, - resnextblock(inplanes, outplanes, cardinality, width, i != 1), - skip_projection(inplanes, outplanes, i != 1))) - # push remaining skip connections on using second residual - inplanes = outplanes - for _ in 2:nrepeats - push!(layers, - Parallel(connection, - resnextblock(inplanes, outplanes, cardinality, width, false), - skip_identity(inplanes, outplanes, false))) - end - baseplanes = outplanes - # double width after every cluster of blocks - width *= widen_factor - end - return Chain(Chain(layers), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(inplanes, nclasses))) -end - -""" - ResNeXt(cardinality, width; block_config, nclasses = 1000) - -Create a ResNeXt model -([reference](https://arxiv.org/abs/1611.05431)). 
- -# Arguments - - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -struct ResNeXt - layers::Any -end - -function ResNeXt(cardinality, width; block_config, nclasses = 1000) - layers = resnext(cardinality, width; block_config, nclasses) - return ResNeXt(layers) -end - -@functor ResNeXt - -(m::ResNeXt)(x) = m.layers(x) - -backbone(m::ResNeXt) = m.layers[1] -classifier(m::ResNeXt) = m.layers[2] - -const resnext_config = Dict(50 => (3, 4, 6, 3), - 101 => (3, 4, 23, 3), - 152 => (3, 8, 36, 3)) - -""" - ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000) - -Create a ResNeXt model with specified configuration. Currently supported values for `config` are (50, 101). -([reference](https://arxiv.org/abs/1611.05431)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -!!! warning - - `ResNeXt` does not currently support pretrained weights. - -See also [`Metalhead.resnext`](#). -""" -function ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, - nclasses = 1000) - @assert config in keys(resnext_config) "`config` must be one of $(sort(collect(keys(resnext_config))))" - model = ResNeXt(cardinality, width; block_config = resnext_config[config], nclasses) - pretrain && loadpretrain!(model, string("ResNeXt", config)) - return model -end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 8e6202085..df66a3e7f 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,54 +1,53 @@ """ - DropBlock(drop_prob = 0.1, block_size = 7) + DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) Implements DropBlock, a regularization method for convolutional networks. 
([reference](https://arxiv.org/pdf/1810.12890.pdf)) """ struct DropBlock{F} - drop_prob::F + drop_block_prob::F block_size::Integer gamma_scale::F end @functor DropBlock -(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size, m.gamma_scale) +(m::DropBlock)(x) = dropblock(x, m.drop_block_prob, m.block_size, m.gamma_scale) -function DropBlock(drop_prob = 0.1, block_size = 7, gamma_scale = 1.0) - return DropBlock(drop_prob, block_size, gamma_scale) +function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) + if drop_block_prob == 0.0 + return identity + end + @assert drop_block_prob < 0 || drop_block_prob > 1 + "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert gamma_scale < 0 || gamma_scale > 1 + "gamma_scale must be between 0 and 1, got $gamma_scale" + return DropBlock(drop_block_prob, block_size, gamma_scale) end -function _dropblock_checks(x, drop_prob, gamma_scale, T) +function _dropblock_checks(x::T) where {T} if !(T <: AbstractArray) throw(ArgumentError("x must be an `AbstractArray`")) end if ndims(x) != 4 throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) end - @assert drop_prob < 0||drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob" - @assert gamma_scale < 0||gamma_scale > 1 "gamma_scale must be between 0 and 1, got $gamma_scale" -end -ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, gamma_scale, T) - -function dropblock(x::T, drop_prob, block_size::Integer, gamma_scale) where {T} - _dropblock_checks(x, drop_prob, gamma_scale, T) - if drop_prob == 0 - return x - end - return _dropblock(x, drop_prob, block_size, gamma_scale) end +ChainRulesCore.@non_differentiable _dropblock_checks(x) -function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size, gamma_scale) where {T} +function dropblock(x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale) where {T} + _dropblock_checks(x) H, W, _, _ = size(x) total_size = H * W clipped_block_size = min(block_size, min(H, W)) - gamma = gamma_scale * drop_prob * total_size / clipped_block_size^2 / + gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / ((W - block_size + 1) * (H - block_size + 1)) block_mask = rand_like(x) .< gamma - block_mask = maxpool(convert(T, block_mask), (clipped_block_size, clipped_block_size); - stride = 1, padding = clipped_block_size ÷ 2) + block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); + stride = 1, pad = clipped_block_size ÷ 2) block_mask = 1 .- block_mask normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) - return x * block_mask * normalize_scale + return x .* block_mask .* normalize_scale end """ diff --git a/test/convnets.jl b/test/convnets.jl index 97cfd846e..11e472743 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -30,23 +30,16 @@ GC.gc() @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] m = ResNet(sz) @test size(m(x_256)) == (1000, 1) - if (ResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz, pretrain = true)) - else - @test_throws ArgumentError ResNet(sz, pretrain = true) - end + ## TODO: find a way to port pretrained models to the new ResNet API + # if (ResNet, sz) in PRETRAINED_MODELS + # @test acctest(ResNet(sz, pretrain = true)) + # else + # @test_throws ArgumentError ResNet(sz, pretrain = true) + # end @test gradtest(m, x_256) GC.safepoint() GC.gc() end - - @testset "Shortcut C" begin - m = Metalhead.resnet(Metalhead.basicblock, :C; - channel_config = [1, 1], - block_config = [2, 2, 2, 2]) 
- @test size(m(x_256)) == (1000, 1) - @test gradtest(m, x_256) - end end GC.safepoint() From 2e882014b503abd4222cdcb6d8fc7a83e10a3a23 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 23 Jun 2022 12:32:21 +0530 Subject: [PATCH 29/64] Tweaks - I --- src/convnets/inception.jl | 2 +- src/layers/attention.jl | 8 ++++---- src/layers/normalise.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index e4106e957..ead229551 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -340,7 +340,7 @@ struct Inceptionv4 end function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = inceptionv4(; inchannels, dropout, nclasses) + layers = inceptionv4(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "Inceptionv4") return Inceptionv4(layers) end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index b6e7b7678..3cefe7c0d 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,5 +1,5 @@ """ - MHAttention(nheads::Integer, qkv_layer, attn_drop, projection) + MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection) Multi-head self-attention layer. @@ -34,9 +34,9 @@ function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = fals attn_drop_rate = 0.0, proj_drop_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_drop_rate) + attn_drop_rate = Dropout(attn_drop_rate) proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate)) - return MHAttention(nheads, qkv_layer, attn_drop, proj) + return MHAttention(nheads, qkv_layer, attn_drop_rate, proj) end @functor MHAttention @@ -52,7 +52,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} seq_len * batch_size) query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) - attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) + attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) pre_projection = reshape(batched_mul(attention, value_reshaped), diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 2d5e6399a..c767bd1e0 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -24,4 +24,4 @@ function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) return ChannelLayerNorm(diag, ϵ) end -(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) +(m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) From 01eaa8b1d8436fdf101291c49607d933076f1023 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 25 Jun 2022 15:42:22 +0530 Subject: [PATCH 30/64] Make pretrain condition explicit --- src/convnets/alexnet.jl | 4 +++- src/convnets/densenet.jl | 4 +++- src/convnets/googlenet.jl | 4 +++- src/convnets/inception.jl | 16 ++++++++++++---- src/convnets/mobilenet.jl | 11 +++++++++-- src/convnets/squeezenet.jl | 4 +++- 6 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 405272dd2..87f2c288e 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -49,7 +49,9 @@ 
end function AlexNet(; pretrain = false, nclasses = 1000) layers = alexnet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "AlexNet") + if pretrain + loadpretrain!(layers, "AlexNet") + end return AlexNet(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 374909bb1..9da4e08b2 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -162,6 +162,8 @@ See also [`Metalhead.densenet`](#). function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))." model = DenseNet(densenet_config[config]; nclasses = nclasses) - pretrain && loadpretrain!(model, string("DenseNet", config)) + if pretrain + loadpretrain!(model, string("DenseNet", config)) + end return model end diff --git a/src/convnets/googlenet.jl b/src/convnets/googlenet.jl index 318463494..946d0d7f7 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/googlenet.jl @@ -86,7 +86,9 @@ end function GoogLeNet(; pretrain = false, nclasses = 1000) layers = googlenet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "GoogLeNet") + if pretrain + loadpretrain!(layers, "GoogLeNet") + end return GoogLeNet(layers) end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index ead229551..ba30fa86f 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -182,7 +182,9 @@ end function Inceptionv3(; pretrain = false, nclasses = 1000) layers = inceptionv3(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "Inceptionv3") + if pretrain + loadpretrain!(layers, "Inceptionv3") + end return Inceptionv3(layers) end @@ -341,7 +343,9 @@ end function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionv4(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "Inceptionv4") + if pretrain + loadpretrain!(layers, "Inceptionv4") + end return Inceptionv4(layers) end @@ -476,7 +480,9 @@ end function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "InceptionResNetv2") + if pretrain + loadpretrain!(layers, "InceptionResNetv2") + end return InceptionResNetv2(layers) end @@ -584,7 +590,9 @@ Creates an Xception model. 
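As an aside (not part of the patch), with the keyword fix from the previous commit and the explicit `pretrain` guards above, the Inception-family constructors are called as below; the values here are arbitrary:

```julia
using Metalhead

# `drop_rate` controls the classifier-head dropout; weights are only fetched
# when `pretrain = true`.
m = Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.2, nclasses = 100)
```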
""" function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = xception(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "xception") + if pretrain + loadpretrain!(layers, "xception") + end return Xception(layers) end diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 93eba1c06..b7dfcd6f3 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -90,7 +90,9 @@ end function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv1")) + if pretrain + loadpretrain!(layers, string("MobileNetv1")) + end return MobileNetv1(layers) end @@ -189,6 +191,9 @@ function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) pretrain && loadpretrain!(layers, string("MobileNetv2")) + if pretrain + loadpretrain!(layers, string("MobileNetv2")) + end return MobileNetv2(layers) end @@ -319,7 +324,9 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = max_width = (mode == :large) ? 1280 : 1024 layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv3", mode)) + if pretrain + loadpretrain!(layers, string("MobileNetv3", mode)) + end return MobileNetv3(layers) end diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index c4de36acc..df458f9ff 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -68,7 +68,9 @@ end function SqueezeNet(; pretrain = false) layers = squeezenet() - pretrain && loadpretrain!(layers, "SqueezeNet") + if pretrain + loadpretrain!(layers, "SqueezeNet") + end return SqueezeNet(layers) end From 546b131980946fb3ca1e1eab666a96f49b01488d Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 28 Jun 2022 18:39:41 +0530 Subject: [PATCH 31/64] More declarative interface for ResNet 1. Less keywords for the user to worry about 2. Delete `ResNeXt` just for now --- src/convnets/efficientnet.jl | 111 ++++++------ src/convnets/mobilenet.jl | 2 +- src/convnets/resne(x)t.jl | 315 +++++++++-------------------------- 3 files changed, 134 insertions(+), 294 deletions(-) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 1465eb238..da9000468 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -6,19 +6,21 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). 
# Arguments -- `scalings`: global width and depth scaling (given as a tuple) -- `block_config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - `n`: number of block repetitions (will be scaled by global depth scaling) - - `k`: kernel size - - `s`: kernel stride - - `e`: expansion ratio - - `i`: block input channels (will be scaled by global width scaling) - - `o`: block output channels (will be scaled by global width scaling) -- `inchannels`: number of input channels -- `nclasses`: number of output classes -- `max_width`: maximum number of output channels before the fully connected - classification blocks + - `scalings`: global width and depth scaling (given as a tuple) + + - `block_config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `n`: number of block repetitions (will be scaled by global depth scaling) + + `k`: kernel size + + `s`: kernel stride + + `e`: expansion ratio + + `i`: block input channels (will be scaled by global width scaling) + + `o`: block output channels (will be scaled by global width scaling) + - `inchannels`: number of input channels + - `nclasses`: number of output classes + - `max_width`: maximum number of output channels before the fully connected + classification blocks """ function efficientnet(scalings, block_config; inchannels = 3, nclasses = 1000, max_width = 1280) @@ -64,34 +66,33 @@ end # i: block input channels # o: block output channels const efficientnet_block_configs = [ -# (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), + # (n, k, s, e, i, o) + (1, 3, 1, 1, 32, 16), + (2, 3, 2, 6, 16, 24), + (2, 5, 2, 6, 24, 40), + (3, 3, 2, 6, 40, 80), + (3, 5, 1, 6, 80, 112), (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320) + (1, 3, 1, 6, 192, 320), ] # w: width scaling # d: depth scaling # r: image resolution const efficientnet_global_configs = Dict( -# ( r, ( w, d)) - :b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6)) -) + # (r, (w, d)) + :b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) struct EfficientNet - layers::Any + layers::Any end """ @@ -103,27 +104,29 @@ See also [`efficientnet`](#). 
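As an aside (not part of the patch), the `(w, d)` pairs in `efficientnet_global_configs` scale the block table above: the width factor multiplies channel counts and the depth factor multiplies repeat counts. A sketch of that arithmetic for the `:b3` entry, reusing the package's `_round_channels` helper for multiple-of-8 rounding (the exact rounding performed inside `efficientnet` is not shown in this excerpt):

```julia
using Metalhead

w, d = 1.2, 1.4                          # global scalings for :b3
n, k, s, e, i, o = (2, 3, 2, 6, 16, 24)  # second row of the block table

scaled_repeats = ceil(Int, d * n)                   # 3 repetitions instead of 2
scaled_in  = Metalhead._round_channels(w * i, 8)    # 19.2 -> 24 channels
scaled_out = Metalhead._round_channels(w * o, 8)    # 28.8 -> 32 channels
```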
# Arguments -- `scalings`: global width and depth scaling (given as a tuple) -- `block_config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - `n`: number of block repetitions (will be scaled by global depth scaling) - - `k`: kernel size - - `s`: kernel stride - - `e`: expansion ratio - - `i`: block input channels (will be scaled by global width scaling) - - `o`: block output channels (will be scaled by global width scaling) -- `inchannels`: number of input channels -- `nclasses`: number of output classes -- `max_width`: maximum number of output channels before the fully connected - classification blocks + - `scalings`: global width and depth scaling (given as a tuple) + + - `block_config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `n`: number of block repetitions (will be scaled by global depth scaling) + + `k`: kernel size + + `s`: kernel stride + + `e`: expansion ratio + + `i`: block input channels (will be scaled by global width scaling) + + `o`: block output channels (will be scaled by global width scaling) + - `inchannels`: number of input channels + - `nclasses`: number of output classes + - `max_width`: maximum number of output channels before the fully connected + classification blocks """ function EfficientNet(scalings, block_config; inchannels = 3, nclasses = 1000, max_width = 1280) - layers = efficientnet(scalings, block_config; - inchannels = inchannels, - nclasses = nclasses, - max_width = max_width) - return EfficientNet(layers) + layers = efficientnet(scalings, block_config; + inchannels = inchannels, + nclasses = nclasses, + max_width = max_width) + return EfficientNet(layers) end @functor EfficientNet @@ -141,13 +144,13 @@ See also [`efficientnet`](#). # Arguments -- `name`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) -- `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `name`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) @assert name in keys(efficientnet_global_configs) - "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" + "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) pretrain && loadpretrain!(model, string("efficientnet-", name)) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index b7dfcd6f3..25067a631 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -28,7 +28,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
function mobilenetv1(width_mult, config; activation = relu, inchannels = 3, - fcsize = 1024, + fcsize = 1024, nclasses = 1000) layers = [] for (dw, outch, stride, nrepeats) in config diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index abf7193b1..50b75fbfd 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -8,18 +8,18 @@ function _drop_blocks(drop_block_prob = 0.0) ] end -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, +function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, first_dilation = nothing, norm_layer = BatchNorm) kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size first_dilation = kernel_size[1] > 1 ? (!isnothing(first_dilation) ? first_dilation : dilation) : 1 pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, dilation = first_dilation, bias = false), - norm_layer(out_channels)) + norm_layer(outplanes)) end -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, +function downsample_avg(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, first_dilation = nothing, norm_layer = BatchNorm) avg_stride = dilation == 1 ? stride : 1 if stride == 1 && dilation == 1 @@ -29,8 +29,8 @@ function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dila pool = MeanPool((2, 2); stride = avg_stride, pad) end return Chain(pool, - Conv((1, 1), in_channels => out_channels; bias = false), - norm_layer(out_channels)) + Conv((1, 1), inplanes => outplanes; bias = false), + norm_layer(outplanes)) end function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, @@ -78,16 +78,61 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina end expansion_factor(::typeof(bottleneck)) = 4 +function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, + norm_layer = BatchNorm, activation = relu) + @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + # Main stem + inplanes = stem_type == :deep ? 
stem_width * 2 : 64 + if stem_type == :deep + stem_channels = (stem_width, stem_width) + if stem_type == :deep_tiered + stem_channels = (3 * (stem_width ÷ 4), stem_width) + end + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + bias = false), + norm_layer(stem_channels[1], activation), + Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, + bias = false), + norm_layer(stem_channels[2], activation), + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) + else + conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) + end + bn1 = norm_layer(inplanes, activation) + # Stem pooling + if replace_stem_pool + stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, + bias = false), + norm_layer(inplanes, activation)) + else + stempool = MaxPool((3, 3); stride = 2, pad = 1) + end + return Chain(conv1, bn1, stempool) +end + +function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), + stride = 1, dilation = 1, first_dilation = dilation, + norm_layer = BatchNorm) + if stride != 1 || inplanes != planes * expansion + downsample = downsample_fn(kernel_size, inplanes, planes * expansion; + stride, dilation, first_dilation, + norm_layer) + else + downsample = identity + end + return downsample +end + # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function _make_blocks(block_fn, channels, block_repeats, inplanes; - reduce_first = 1, output_stride = 32, down_kernel_size = (1, 1), - avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - kwargs...) +function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, + downsample_fn = downsample_conv, downsample_args::NamedTuple = (), + drop_block_rate = 0.0, drop_path_rate = 0.0, + block_args::NamedTuple = ()) + @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" expansion = expansion_factor(block_fn) - kwarg_dict = Dict(kwargs...) stages = [] net_block_idx = 1 net_stride = 4 @@ -103,17 +148,10 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; else net_stride *= stride end - # use average pooling for projection skip connection between stages/downsample. - downsample = identity - if stride != 1 || inplanes != planes * expansion - downsample_fn = avg_down ? downsample_avg : downsample_conv - downsample = downsample_fn(down_kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation = dilation, - norm_layer = kwarg_dict[:norm_layer]) - end - # arguments to be passed into the block function - block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => drop_block, kwargs...) + # Downsample block; either a (default) convolution-based block or a pooling-based block. + downsample = downsample_block(downsample_fn, inplanes, planes, expansion; + downsample_args...) + # Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? 
downsample : identity @@ -123,7 +161,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, - drop_path = DropPath(block_dpr), block_kwargs...)) + drop_path = DropPath(block_dpr), drop_block, block_args...)) prev_dilation = dilation inplanes = planes * expansion net_block_idx += 1 @@ -133,103 +171,25 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; return Chain(stages...) end -""" - resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), - avg_down = false, activation = relu, norm_layer = BatchNorm, - drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, - block_kwargs...) - -Creates the layers of a ResNe(X)t model. If you are an end-user, you should probably use -[ResNet](@ref) instead and pass in the parameters you want to modify as optional parameters -there. - -# Arguments: - - - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for - example. - - - `layers`: A list of integers representing the number of blocks in each stage. - - `nclasses`: The number of output classes. The default value is 1000. - - `inchannels`: The number of input channels to the model. The default value is 3. - - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - This is used for [ResNeXt](@ref)-like models. The default value is 1. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels: `planes * base_width / 64 * cardinality`. - The default value is 64. - - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. - - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: - - + `:default` - a single 7x7 convolution layer with a width of `stem_width` - + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` - + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. - The default value is `:default`. - - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a - convolution layer. The default value is false. - - `reduce_first`: Reduction factor for first convolution output width of residual blocks, - Default is 1 for all architectures except SE-Nets, where it is 2. - - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the - skip connection. The default value is (1, 1) for all architectures except - SE-Nets, where it is (3, 3). - - `avg_down`: Use average pooling for projection skip connection between stages/downsample. - - `activation`: The activation function to use. The default value is `relu`. - - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. - - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `DropPath`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. - - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. 
- -If you are an end-user trying to tweak the ResNet model, note that there is no guarantee that -all combinations of parameters will work. In particular, tweaking `block_kwargs` is not -advised unless you know what you are doing. -""" function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), - avg_down = false, activation = relu, norm_layer = BatchNorm, - drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, - block_kwargs...) - @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" - @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + stem_fn = resnet_stem, stem_args::NamedTuple = (), + downsample_fn = downsample_conv, downsample_args::NamedTuple = (), + drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = 0.0), + block_args::NamedTuple = ()) # Stem - inplanes = stem_type == :deep ? stem_width * 2 : 64 - if stem_type == :deep - stem_channels = (stem_width, stem_width) - if stem_type == :deep_tiered - stem_channels = (3 * (stem_width ÷ 4), stem_width) - end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, - bias = false), - norm_layer(stem_channels[1], activation), - Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, - bias = false), - norm_layer(stem_channels[2], activation), - Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) - else - conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) - end - bn1 = norm_layer(inplanes, activation) - # Stem pooling - if replace_stem_pool - stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, - bias = false), - norm_layer(inplanes, activation)) - else - stempool = MaxPool((3, 3); stride = 2, pad = 1) - end - stem = Chain(conv1, bn1, stempool) + stem = stem_fn(; inchannels, stem_args...) # Feature Blocks channels = [64, 128, 256, 512] - stage_blocks = _make_blocks(block, channels, layers, inplanes; cardinality, base_width, - output_stride, reduce_first, avg_down, - down_kernel_size, activation, norm_layer, - drop_block_rate, drop_path_rate, block_kwargs...) + stage_blocks = _make_blocks(block, channels, layers, inchannels; + output_stride, downsample_fn, downsample_args, + drop_block_rate = drop_rates.drop_block_rate, + drop_path_rate = drop_rates.drop_path_rate, + block_args) # Head (Pooling and Classifier) expansion = expansion_factor(block) num_features = 512 * expansion - classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, + classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten, Dense(num_features, nclasses)) return Chain(Chain(stem, stage_blocks), classifier) end @@ -239,59 +199,6 @@ const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) - -""" - ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) - -Creates a ResNet model. -((reference)[https://arxiv.org/abs/1512.03385]) - -# Arguments: - - - `depth`: The depth of the `ResNet` model. Must be one of `[18, 34, 50, 101, 152]`. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. - - `nclasses`: The number of output classes. The default value is 1000. 
- -Apart from these, the model can also take any of the following optional arguments: - - - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for - example. - - - `layers`: A list of integers representing the number of blocks in each stage. - - `inchannels`: The number of input channels to the model. The default value is 3. - - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - This is used for [ResNeXt](@ref)-like models. The default value is 1. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels: `planes * base_width / 64 * cardinality`. - The default value is 64. - - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. - - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: - - + `:default` - a single 7x7 convolution layer with a width of `stem_width` - + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` - + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. - The default value is `:default`. - - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a - convolution layer. The default value is false. - - `reduce_first`: Reduction factor for first convolution output width of residual blocks, - Default is 1 for all architectures except SE-Nets, where it is 2. - - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the - skip connection. The default value is (1, 1) for all architectures except - SE-Nets, where it is (3, 3). - - `avg_down`: Use average pooling for projection skip connection between stages/downsample. - - `activation`: The activation function to use. The default value is `relu`. - - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. - - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. - - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. - -See also [`resnet`](@ref) for more details. - -!!! warning - - Pretrained models are not supported for all parameter combinations of `ResNet`. -""" struct ResNet layers::Any end @@ -300,78 +207,8 @@ end function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" model = resnet(resnet_config[depth]...; nclasses, kwargs...) - pretrain && loadpretrain!(model, string("resnet", depth)) - return model -end - -""" - ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, nclasses = 1000, - kwargs...) - -Creates a ResNeXt model. -((reference)[https://arxiv.org/abs/1611.05431]) - -# Arguments: - - - `depth`: The depth of the `ResNeXt` model. Must be one of `[50, 101, 152]`. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - of the `ResNeXt` mode. The default value is 4. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels. The default value is 32. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. 
- - `nclasses`: The number of output classes. The default value is 1000. - -Apart from these, the model can also take any of the following optional arguments: - - - `block`: The block to use in the ResNet model. See [basicblock](@ref) and [bottleneck](@ref) for - example. - - - `layers`: A list of integers representing the number of blocks in each stage. - - `inchannels`: The number of input channels to the model. The default value is 3. - - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32. - - `cardinality`: The number of convolution groups for the 3x3 convolution in the bottleneck block. - This is used for [ResNeXt](@ref)-like models. The default value is 1. - - `base_width`: The base width of each bottleneck block. It is the factor determining - the number of bottleneck channels: `planes * base_width / 64 * cardinality`. - The default value is 64. - - `stem_width`: The number of channels in the convolution in the stem. The default value is 64. - - `stem_type`: The type of the stem. Must be one of `[:default, :deep, :wide]`: - - + `:default` - a single 7x7 convolution layer with a width of `stem_width` - + `:deep` - three 3x3 convolution layers of widths `stem_width`, `stem_width`, `stem_width * 2` - + `:deep_tiered` - three 3x3 conv layers of widths `stem_width ÷ 4 * 3`, `stem_width`, `stem_width * 2`. - The default value is `:default`. - - `replace_stem_pool`: Whether to replace the pooling layer of the stem with a - convolution layer. The default value is false. - - `reduce_first`: Reduction factor for first convolution output width of residual blocks, - Default is 1 for all architectures except SE-Nets, where it is 2. - - `down_kernel_size`: The kernel size of the convolution in the downsample layer of the - skip connection. The default value is (1, 1) for all architectures except - SE-Nets, where it is (3, 3). - - `avg_down`: Use average pooling for projection skip connection between stages/downsample. - - `activation`: The activation function to use. The default value is `relu`. - - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`. - - `drop_rate`: The rate to use for `Dropout`. The default value is 0.0. - - `drop_path_rate`: The rate to use for `(DropPath`)[@ref]`. The default value is 0.0. - - `drop_block_rate`: The rate to use for `(DropBlock)[@ref]`. The default value is 0.0. - -See also [`resnet`](@ref) for more details. - -!!! warning - - Pretrained models are not currently supported for `ResNeXt`. -""" -struct ResNeXt - layers::Any -end -@functor ResNeXt - -function ResNeXt(depth::Integer; cardinality = 4, base_width = 32, pretrain = false, - nclasses = 1000, - kwargs...) - @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - model = resnet(resnet_config[depth]...; cardinality, base_width, nclasses, kwargs...) 
- pretrain && - loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + if pretrain + loadpretrain!(model, string("resnet", depth)) + end return model end From 3f45f277fab00071452bf51e868c46d63b722331 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 28 Jun 2022 22:58:04 +0530 Subject: [PATCH 32/64] Make `DropBlock` really work --- Project.toml | 3 +- src/Metalhead.jl | 4 +-- src/convnets/resne(x)t.jl | 16 ++++----- src/layers/Layers.jl | 2 ++ src/layers/drop.jl | 75 ++++++++++++++++++++++----------------- 5 files changed, 57 insertions(+), 43 deletions(-) diff --git a/Project.toml b/Project.toml index 42f12a887..c83b146a5 100644 --- a/Project.toml +++ b/Project.toml @@ -5,14 +5,15 @@ version = "0.7.3" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 5463c64de..a9251150f 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -39,7 +39,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -48,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, :ResNet, +for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 50b75fbfd..eadc3d047 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -107,7 +107,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - return Chain(conv1, bn1, stempool) + return inplanes, Chain(conv1, bn1, stempool) end function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), @@ -150,7 +150,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride end # Downsample block; either a (default) convolution-based block or a pooling-based block. downsample = downsample_block(downsample_fn, inplanes, planes, expansion; - downsample_args...) + stride, dilation, first_dilation = dilation, downsample_args...) 
# Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks @@ -172,16 +172,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride end function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - stem_fn = resnet_stem, stem_args::NamedTuple = (), - downsample_fn = downsample_conv, downsample_args::NamedTuple = (), + stem_fn = resnet_stem, stem_args::NamedTuple = NamedTuple(), + downsample_fn = downsample_conv, downsample_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - block_args::NamedTuple = ()) + drop_block_rate = 0.5), + block_args::NamedTuple = NamedTuple()) # Stem - stem = stem_fn(; inchannels, stem_args...) + inplanes, stem = stem_fn(; inchannels, stem_args...) # Feature Blocks channels = [64, 128, 256, 512] - stage_blocks = _make_blocks(block, channels, layers, inchannels; + stage_blocks = _make_blocks(block, channels, layers, inplanes; output_stride, downsample_fn, downsample_args, drop_block_rate = drop_rates.drop_block_rate, drop_path_rate = drop_rates.drop_path_rate, diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 6c417c077..1e75b53d6 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,12 +1,14 @@ module Layers using Flux +using CUDA using NNlib using NNlibCUDA using Functors using ChainRulesCore using Statistics using MLUtils +using Random include("../utilities.jl") diff --git a/src/layers/drop.jl b/src/layers/drop.jl index df66a3e7f..fdbdc7db7 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,28 +1,33 @@ -""" - DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) +function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale, active::Bool = true) where {T} + active || return x + H, W, _, _ = size(x) + total_size = H * W + clipped_block_size = min(block_size, min(H, W)) + gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / + ((W - block_size + 1) * (H - block_size + 1)) + block_mask = rand_like(rng, x) .< gamma + block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); + stride = 1, pad = clipped_block_size ÷ 2) + block_mask = 1 .- block_mask + normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) + return x .* block_mask .* normalize_scale +end +dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) +function dropblock(rng, x::CuArray, p; kwargs...) + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays.")) +end -Implements DropBlock, a regularization method for convolutional networks. 
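As an aside (not part of the patch), plugging numbers into the `gamma` expression above makes the seeding probability concrete. With the documented defaults (`drop_block_prob = 0.1`, `block_size = 7`, `gamma_scale = 1.0`) and an illustrative 28×28 feature map:

```julia
H = W = 28
drop_block_prob, block_size, gamma_scale = 0.1, 7, 1.0

total_size = H * W
clipped_block_size = min(block_size, min(H, W))
gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 /
        ((W - block_size + 1) * (H - block_size + 1))
# gamma ≈ 0.0033: each position seeds a block only rarely, but every seed removes a
# block_size × block_size patch, so roughly `drop_block_prob` of the map ends up dropped.
```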
-([reference](https://arxiv.org/pdf/1810.12890.pdf)) -""" -struct DropBlock{F} +struct DropBlock{F, R <: AbstractRNG} drop_block_prob::F block_size::Integer gamma_scale::F + active::Union{Bool, Nothing} + rng::R end -@functor DropBlock - -(m::DropBlock)(x) = dropblock(x, m.drop_block_prob, m.block_size, m.gamma_scale) -function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0) - if drop_block_prob == 0.0 - return identity - end - @assert drop_block_prob < 0 || drop_block_prob > 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert gamma_scale < 0 || gamma_scale > 1 - "gamma_scale must be between 0 and 1, got $gamma_scale" - return DropBlock(drop_block_prob, block_size, gamma_scale) -end +@functor DropBlock +trainable(a::DropBlock) = (;) function _dropblock_checks(x::T) where {T} if !(T <: AbstractArray) @@ -34,20 +39,26 @@ function _dropblock_checks(x::T) where {T} end ChainRulesCore.@non_differentiable _dropblock_checks(x) -function dropblock(x::AbstractArray{T, 4}, drop_block_prob, block_size, - gamma_scale) where {T} +function (m::DropBlock)(x) _dropblock_checks(x) - H, W, _, _ = size(x) - total_size = H * W - clipped_block_size = min(block_size, min(H, W)) - gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / - ((W - block_size + 1) * (H - block_size + 1)) - block_mask = rand_like(x) .< gamma - block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); - stride = 1, pad = clipped_block_size ÷ 2) - block_mask = 1 .- block_mask - normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) - return x .* block_mask .* normalize_scale + Flux._isactive(m) || return x + return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) +end + +function Flux.testmode!(m::DropBlock, mode = true) + return (m.active = (isnothing(mode) || mode == :auto) ? 
nothing : !mode; m) +end + +function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, + rng = Flux.rng_from_array()) + if drop_block_prob == 0.0 + return identity + end + @assert 0 ≤ drop_block_prob ≤ 1 + "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0 ≤ gamma_scale ≤ 1 + "gamma_scale must be between 0 and 1, got $gamma_scale" + return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) end """ From f373f459bb7486557a8ee031ca91cd6dab077855 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 29 Jun 2022 08:33:23 +0530 Subject: [PATCH 33/64] Construct the stem outside and pass it into `resnet` `downsample_args` is actually redundant --- src/Metalhead.jl | 1 - src/convnets/resne(x)t.jl | 61 +++++++++++++++++++-------------------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index a9251150f..34610c548 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,7 +7,6 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils -using Random import Functors diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index eadc3d047..4c77260f8 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -1,13 +1,3 @@ -# returns `DropBlock`s for each block of the ResNet -function _drop_blocks(drop_block_prob = 0.0) - return [ - identity, - identity, - DropBlock(drop_block_prob, 5, 0.25), - DropBlock(drop_block_prob, 3, 1.00), - ] -end - function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, first_dilation = nothing, norm_layer = BatchNorm) kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size @@ -80,12 +70,15 @@ expansion_factor(::typeof(bottleneck)) = 4 function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) - @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" + @assert stem_type in [:default, :deep, :deep_tiered] + "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem - inplanes = stem_type == :deep ? stem_width * 2 : 64 - if stem_type == :deep - stem_channels = (stem_width, stem_width) - if stem_type == :deep_tiered + deep_stem = stem_type == :deep || stem_type == :deep_tiered + inplanes = deep_stem ? stem_width * 2 : 64 + if deep_stem + if stem_type == :deep + stem_channels = (stem_width, stem_width) + elseif stem_type == :deep_tiered stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, @@ -107,7 +100,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - return inplanes, Chain(conv1, bn1, stempool) + return Chain(conv1, bn1, stempool), inplanes end function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), @@ -128,9 +121,8 @@ end # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. 
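As an aside (not part of the patch), `resnet_stem` now returns both the stem and the `inplanes` it produces, so the stem can be built outside `resnet` and passed in. A usage sketch for the default stem only:

```julia
using Metalhead

stem, inplanes = Metalhead.resnet_stem(; stem_type = :default, inchannels = 3)
inplanes                                      # 64 for the default 7×7 stem
size(stem(rand(Float32, 224, 224, 3, 1)))     # (56, 56, 64, 1) after conv + maxpool
```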
function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, - downsample_fn = downsample_conv, downsample_args::NamedTuple = (), - drop_block_rate = 0.0, drop_path_rate = 0.0, - block_args::NamedTuple = ()) + downsample_fn = downsample_conv, + drop_rates::NamedTuple, block_args::NamedTuple) @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" expansion = expansion_factor(block_fn) stages = [] @@ -139,7 +131,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride dilation = prev_dilation = 1 for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, - _drop_blocks(drop_block_rate))) + _drop_blocks(drop_rates.drop_block_rate))) # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -148,16 +140,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride else net_stride *= stride end - # Downsample block; either a (default) convolution-based block or a pooling-based block. + # Downsample block; either a (default) convolution-based block or a pooling-based block downsample = downsample_block(downsample_fn, inplanes, planes, expansion; - stride, dilation, first_dilation = dilation, downsample_args...) + stride, dilation, first_dilation = dilation) # Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 # stochastic depth linear decay rule - block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + block_dpr = drop_rates.drop_path_rate * net_block_idx / (sum(block_repeats) - 1) push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, @@ -171,21 +163,26 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride return Chain(stages...) end +# returns `DropBlock`s for each block of the ResNet +function _drop_blocks(drop_block_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_block_prob, 5, 0.25), + DropBlock(drop_block_prob, 3, 1.00), + ] +end + function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, - stem_fn = resnet_stem, stem_args::NamedTuple = NamedTuple(), - downsample_fn = downsample_conv, downsample_args::NamedTuple = NamedTuple(), + stem = first(resnet_stem(; inchannels)), inplanes = 64, + downsample_fn = downsample_conv, drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.5), + drop_block_rate = 0.0), block_args::NamedTuple = NamedTuple()) - # Stem - inplanes, stem = stem_fn(; inchannels, stem_args...) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = _make_blocks(block, channels, layers, inplanes; - output_stride, downsample_fn, downsample_args, - drop_block_rate = drop_rates.drop_block_rate, - drop_path_rate = drop_rates.drop_path_rate, - block_args) + output_stride, downsample_fn, drop_rates, block_args) # Head (Pooling and Classifier) expansion = expansion_factor(block) num_features = 512 * expansion From 51d07573d18cd69055ac6728f779e254faacea24 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 29 Jun 2022 18:40:00 +0530 Subject: [PATCH 34/64] Add ResNeXt back Also add tests. 
A lot of tests --- src/Metalhead.jl | 5 ++-- src/convnets/resne(x)t.jl | 28 ++++++++++++++----- src/layers/Layers.jl | 3 +-- src/layers/drop.jl | 4 +-- src/utilities.jl | 18 ------------- test/convnets.jl | 57 ++++++++++++++++++++++++++++----------- 6 files changed, 69 insertions(+), 46 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 34610c548..172f01d16 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,6 +7,7 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils +using Random import Functors @@ -38,7 +39,7 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, @@ -47,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt, +for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 4c77260f8..a596fc2d1 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -173,7 +173,7 @@ function _drop_blocks(drop_block_prob = 0.0) ] end -function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32, +function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, downsample_fn = downsample_conv, drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, @@ -181,10 +181,10 @@ function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = block_args::NamedTuple = NamedTuple()) # Feature Blocks channels = [64, 128, 256, 512] - stage_blocks = _make_blocks(block, channels, layers, inplanes; + stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fn, drop_rates, block_args) # Head (Pooling and Classifier) - expansion = expansion_factor(block) + expansion = expansion_factor(block_fn) num_features = 512 * expansion classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten, Dense(num_features, nclasses)) @@ -201,11 +201,27 @@ struct ResNet end @functor ResNet -function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...) - @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - model = resnet(resnet_config[depth]...; nclasses, kwargs...) +function ResNet(depth::Integer; pretrain = false, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + model = resnet(resnet_config[depth]...; nclasses) if pretrain loadpretrain!(model, string("resnet", depth)) end return model end + +struct ResNeXt + layers::Any +end +@functor ResNeXt + +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, block_args = (; cardinality, base_width)) + if pretrain + loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + end + return model +end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 1e75b53d6..efefd91b2 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -2,8 +2,7 @@ module Layers using Flux using CUDA -using NNlib -using NNlibCUDA +using NNlib, NNlibCUDA using Functors using ChainRulesCore using Statistics diff --git a/src/layers/drop.jl b/src/layers/drop.jl index fdbdc7db7..dbb7ddc34 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -13,8 +13,8 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, bl normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) return x .* block_mask .* normalize_scale end -dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) -function dropblock(rng, x::CuArray, p; kwargs...) +dropoutblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) +function dropblock(rng, x::CuArray, p, args...) throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays.")) end diff --git a/src/utilities.jl b/src/utilities.jl index 0c4f46796..930cc621a 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -9,24 +9,6 @@ function _round_channels(channels, divisor, min_value = divisor) return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels end -""" - addrelu(x, y) - -Convenience function for `(x, y) -> @. relu(x + y)`. -Useful as the `connection` argument for [`resnet`](#). -See also [`reluadd`](#). -""" -addrelu(x, y) = @. relu(x + y) - -""" - reluadd(x, y) - -Convenience function for `(x, y) -> @. relu(x) + relu(y)`. -Useful as the `connection` argument for [`resnet`](#). -See also [`addrelu`](#). -""" -reluadd(x, y) = @. relu(x) + relu(y) - """ cat_channels(x, y, zs...) 
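As an aside (not part of the patch), ahead of the updated tests below, a usage sketch of the reworked interface with arbitrary argument values: a ResNeXt-style network is now `resnet` with `bottleneck` plus `block_args`, and the regularisation rates travel together in the `drop_rates` named tuple.

```julia
using Metalhead

# ResNet-50-style layout with all three regularisers enabled.
m1 = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                      drop_rates = (drop_rate = 0.1, drop_path_rate = 0.1,
                                    drop_block_rate = 0.1))

# ResNeXt-50 (32×4d)-style layout via `block_args`, mirroring the `ResNeXt` wrapper.
m2 = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                      block_args = (; cardinality = 32, base_width = 4))

size(m1(rand(Float32, 224, 224, 3, 1)))       # (1000, 1)
```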
diff --git a/test/convnets.jl b/test/convnets.jl index 11e472743..53006360b 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -27,18 +27,39 @@ GC.safepoint() GC.gc() @testset "ResNet" begin - @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] - m = ResNet(sz) - @test size(m(x_256)) == (1000, 1) - ## TODO: find a way to port pretrained models to the new ResNet API + # Tests for pretrained ResNets + ## TODO: find a way to port pretrained models to the new ResNet API + # @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] # if (ResNet, sz) in PRETRAINED_MODELS # @test acctest(ResNet(sz, pretrain = true)) # else # @test_throws ArgumentError ResNet(sz, pretrain = true) # end - @test gradtest(m, x_256) - GC.safepoint() - GC.gc() + # end + + @testset "resnet" begin + @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3] + ] + @testset for layers in layer_list + drop_list = [ + (drop_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), + (drop_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), + (drop_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), + ] + @testset for drop_rates in drop_list + m = Metalhead.resnet(block_fn, layers; drop_rates) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + GC.safepoint() + GC.gc() + end + end + end end end @@ -47,16 +68,20 @@ GC.gc() @testset "ResNeXt" begin @testset for depth in [50, 101, 152] - m = ResNeXt(depth) - @test size(m(x_224)) == (1000, 1) - if ResNeXt in PRETRAINED_MODELS - @test acctest(ResNeXt(depth, pretrain = true)) - else - @test_throws ArgumentError ResNeXt(depth, pretrain = true) + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if string("resnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + GC.safepoint() + GC.gc() + end end - @test gradtest(m, x_224) - GC.safepoint() - GC.gc() end end From 106f26053175204a38431856bdaef49226c04c26 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 29 Jun 2022 19:54:22 +0530 Subject: [PATCH 35/64] Add more general implementation of SE layer Also 1. Tweaks - II : Formatting + some docs 2. Groundwork for abstracting out the classifier --- src/convnets/convmixer.jl | 4 +- src/convnets/convnext.jl | 2 +- src/convnets/inception.jl | 51 ++++++++++---------- src/convnets/mobilenet.jl | 12 ++--- src/convnets/resne(x)t.jl | 99 +++++++++++++++++++++++++++++++++------ src/convnets/vgg.jl | 28 +++++------ src/layers/Layers.jl | 8 ++-- src/layers/classifier.jl | 12 +++++ src/layers/conv.jl | 25 ++-------- src/layers/embeddings.jl | 2 +- src/layers/mlp-linear.jl | 21 +++++---- src/layers/pool.jl | 26 ++++++++++ src/layers/selayers.jl | 41 ++++++++++++++++ src/other/mlpmixer.jl | 28 ++++++----- src/vit-based/vit.jl | 20 ++++---- test/convnets.jl | 6 +-- 16 files changed, 260 insertions(+), 125 deletions(-) create mode 100644 src/layers/classifier.jl create mode 100644 src/layers/pool.jl create mode 100644 src/layers/selayers.jl diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index f70473ca5..d36f1a8d5 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -9,7 +9,7 @@ Creates a ConvMixer model. 
- `planes`: number of planes in the output of each block - `depth`: number of layers - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `kernel_size`: kernel size of the convolutional layers - `patch_size`: size of the patches - `activation`: activation function used after the convolutional layers @@ -45,7 +45,7 @@ Creates a ConvMixer model. # Arguments - `mode`: the mode of the model, either `:base`, `:small` or `:large` - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `activation`: activation function used after the convolutional layers - `nclasses`: number of classes in the output """ diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 3fef58d1d..f3da6dbf3 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -92,7 +92,7 @@ Creates a ConvNeXt model. # Arguments: - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `drop_path_rate`: Stochastic depth rate. - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239) - `nclasses`: number of output classes diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index ba30fa86f..156362cf3 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -281,7 +281,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) + inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -289,10 +289,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) +function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -315,13 +315,13 @@ function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) inceptionv4_c(), inceptionv4_c(), inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(1536, nclasses)) return Chain(body, head) end """ - Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) + Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -330,7 +330,7 @@ Creates an Inceptionv4 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! 
warning @@ -341,8 +341,9 @@ struct Inceptionv4 layers::Any end -function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = inceptionv4(; inchannels, drop_rate, nclasses) +function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionv4(; inchannels, dropout_rate, nclasses) if pretrain loadpretrain!(layers, "Inceptionv4") end @@ -424,7 +425,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels = 3, drop_rate =0.0, nclasses = 1000) + inceptionresnetv2(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -432,10 +433,10 @@ Creates an InceptionResNetv2 model. # Arguments - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) +function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., conv_bn((3, 3), 32, 32)..., conv_bn((3, 3), 32, 64; pad = 1)..., @@ -451,13 +452,13 @@ function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), conv_bn((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(1536, nclasses)) return Chain(body, head) end """ - InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -466,7 +467,7 @@ Creates an InceptionResNetv2 model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning @@ -477,9 +478,9 @@ struct InceptionResNetv2 layers::Any end -function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, +function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) + layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) if pretrain loadpretrain!(layers, "InceptionResNetv2") end @@ -504,7 +505,7 @@ Create an Xception block. # Arguments - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `outchannels`: number of output channels. - `nrepeats`: number of repeats of depthwise separable convolution layers. - `stride`: stride by which to downsample the input. @@ -541,7 +542,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end """ - xception(; inchannels = 3, drop_rate =0.0, nclasses = 1000) + xception(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -549,10 +550,10 @@ Creates an Xception model. # Arguments - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. 
+ - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. """ -function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) +function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)..., conv_bn((3, 3), 32, 64; bias = false)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), @@ -562,7 +563,7 @@ function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) xception_block(728, 1024, 2; stride = 2, grow_at_start = false), depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate), + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(2048, nclasses)) return Chain(body, head) end @@ -572,7 +573,7 @@ struct Xception end """ - Xception(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000) + Xception(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -581,15 +582,15 @@ Creates an Xception model. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - `inchannels`: number of input channels. - - `drop_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. - `nclasses`: the number of output classes. !!! warning `Xception` does not currently support pretrained weights. """ -function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) - layers = xception(; inchannels, drop_rate, nclasses) +function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, dropout_rate, nclasses) if pretrain loadpretrain!(layers, "xception") end diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 25067a631..15dc037e8 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -21,7 +21,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + `s`: The stride of the convolutional kernel + `r`: The number of time this configuration block is repeated - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `fcsize`: The intermediate fully-connected size between the convolution and final layers - `nclasses`: The number of output classes """ @@ -77,7 +77,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet - `nclasses`: The number of output classes @@ -123,7 +123,7 @@ Create a MobileNetv2 model. + `n`: The number of times a block is repeated + `s`: The stride of the convolutional kernel + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. 
- `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes """ @@ -181,7 +181,7 @@ Set `pretrain` to `true` to load the pretrained weights for ImageNet. - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `pretrain`: Whether to load the pre-trained weights for ImageNet - `nclasses`: The number of output classes @@ -226,7 +226,7 @@ Create a MobileNetv3 model. + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + `s::Integer` - The stride of the convolutional kernel + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. The default value is 3. + - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes """ @@ -312,7 +312,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `pretrain`: whether to load the pre-trained weights for ImageNet - `nclasses`: the number of output classes diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index a596fc2d1..74140d625 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -68,6 +68,34 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina end expansion_factor(::typeof(bottleneck)) = 4 +""" + resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, + norm_layer = BatchNorm, activation = relu) + +Builds a stem to be used in a ResNet model. See the `stem` argument of `resnet` for details +on how to use this function. + +# Arguments: + + - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. + + + `:default`: Builds a stem based on the default ResNet stem, which consists of a single + 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 + max pooling layer with stride 2. + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2 for one) in using a + deeper stem with 3 successive 3x3 convolutions having normalisation layers + after each one. This is followed by a 3x3 max pooling layer with stride 2. + + `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second + convolution. This is an experimental variant from the `timm` library + in Python that shows peformance improvements over the `:deep` stem + in some cases. + + - `inchannels`: The number of channels in the input. + - `replace_stem_pool`: Whether to replace the default 3x3 max pooling layer with a + 3x3 convolution with stride 2 and a normalisation layer. + - `norm_layer`: The normalisation layer used in the stem. + - `activation`: The activation function used in the stem. 
+""" function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) @assert stem_type in [:default, :deep, :deep_tiered] @@ -75,13 +103,14 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = # Main stem deep_stem = stem_type == :deep || stem_type == :deep_tiered inplanes = deep_stem ? stem_width * 2 : 64 + # Deep stem that uses three successive 3x3 convolutions instead of a single 7x7 convolution if deep_stem if stem_type == :deep stem_channels = (stem_width, stem_width) elseif stem_type == :deep_tiered stem_channels = (3 * (stem_width ÷ 4), stem_width) end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, + conv1 = Chain(Conv((3, 3), inchannels => stem_channels[1]; stride = 2, pad = 1, bias = false), norm_layer(stem_channels[1], activation), Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, @@ -129,9 +158,10 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 + dbr = haskey(drop_rates, :drop_block_rate) ? drop_rates.drop_block_rate : 0 for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, - _drop_blocks(drop_rates.drop_block_rate))) + _drop_blocks(dbr))) # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -149,7 +179,8 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 # stochastic depth linear decay rule - block_dpr = drop_rates.drop_path_rate * net_block_idx / (sum(block_repeats) - 1) + dpr = haskey(drop_rates, :drop_path_rate) ? drop_rates.drop_path_rate : 0 + block_dpr = dpr * net_block_idx / (sum(block_repeats) - 1) push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, @@ -163,22 +194,20 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride return Chain(stages...) end -# returns `DropBlock`s for each block of the ResNet +# returns `DropBlock`s for each stage of the ResNet function _drop_blocks(drop_block_prob = 0.0) return [ - identity, - identity, - DropBlock(drop_block_prob, 5, 0.25), - DropBlock(drop_block_prob, 3, 1.00), + identity, identity, + DropBlock(drop_block_prob, 5, 0.25), DropBlock(drop_block_prob, 3, 1.00), ] end function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, - downsample_fn = downsample_conv, - drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0, + downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), + drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), - block_args::NamedTuple = NamedTuple()) + classifier_args::NamedTuple = NamedTuple()) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; @@ -186,11 +215,13 @@ function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride # Head (Pooling and Classifier) expansion = expansion_factor(block_fn) num_features = 512 * expansion - classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten, - Dense(num_features, nclasses)) + global_pool, fc = create_classifier(num_features, nclasses; classifier_args...) + dr = haskey(drop_rates, :dropout_rate) ? 
drop_rates.dropout_rate : 0 + classifier = Chain(global_pool, Dropout(dr), fc) return Chain(Chain(stem, stage_blocks), classifier) end +# block-layer configurations for ResNet and ResNeXt models const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), 34 => (basicblock, [3, 4, 6, 3]), 50 => (bottleneck, [3, 4, 6, 3]), @@ -201,6 +232,23 @@ struct ResNet end @functor ResNet +""" + ResNet(depth::Integer; pretrain = false, nclasses = 1000) + +Creates a ResNet model with the specified depth. + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +!!! warning + + `ResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" function ResNet(depth::Integer; pretrain = false, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" @@ -216,10 +264,31 @@ struct ResNeXt end @functor ResNeXt -function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) +""" + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + +Creates a ResNeXt model with the specified depth, cardinality, and base width. + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `nclasses`: the number of output classes + +!!! warning + + `ResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, block_args = (; cardinality, base_width)) + model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, + block_args = (; cardinality, base_width)) if pretrain loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 15560de7c..957a0a483 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -52,7 +52,7 @@ function vgg_convolutional_layers(config, batchnorm, inchannels) end """ - vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) + vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). 
@@ -63,19 +63,19 @@ Create VGG classifier (fully connected) layers the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `drop_rate`: the dropout level between each fully connected layer + - `dropout_rate`: the dropout level between each fully connected layer """ -function vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) +function vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) return Chain(MLUtils.flatten, Dense(Int(prod(imsize)), fcsize, relu), - Dropout(drop_rate), + Dropout(dropout_rate), Dense(fcsize, fcsize, relu), - Dropout(drop_rate), + Dropout(dropout_rate), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -90,12 +90,12 @@ Create a VGG model - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `drop_rate`: dropout level between fully connected layers + - `dropout_rate`: dropout level between fully connected layers """ -function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) +function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) + class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) return Chain(Chain(conv), class) end @@ -114,7 +114,7 @@ struct VGG end """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -126,11 +126,11 @@ Construct a VGG model with the specified input image size. 
Typically, the image - `nclasses`::Integer : number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `drop_rate`: dropout level between fully connected layers + - `dropout_rate`: dropout level between fully connected layers """ function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, - drop_rate) - layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, drop_rate) + dropout_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, dropout_rate) return VGG(layers) end @@ -159,7 +159,7 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses batchnorm = batchnorm, nclasses = nclasses, fcsize = 4096, - drop_rate = 0.5) + dropout_rate = 0.5) if pretrain && !batchnorm loadpretrain!(model, string("vgg", depth)) elseif pretrain diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index efefd91b2..f58f40172 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -17,14 +17,16 @@ include("mlp-linear.jl") include("normalise.jl") include("conv.jl") include("drop.jl") +include("selayers.jl") +include("classifier.jl") export MHAttention, PatchEmbedding, ViPosEmbedding, ClassTokens, mlp_block, gated_mlp_block, - LayerScale, DropPath, + LayerScale, DropPath, DropBlock, ChannelLayerNorm, prenorm, skip_identity, skip_projection, conv_bn, depthwise_sep_conv_bn, - invertedresidual, squeeze_excite, - DropBlock + squeeze_excite, effective_squeeze_excite, + invertedresidual, create_classifier end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl new file mode 100644 index 000000000..04be6ef86 --- /dev/null +++ b/src/layers/classifier.jl @@ -0,0 +1,12 @@ +function create_classifier(inplanes, nclasses; pool_type = :avg, use_conv = false) + flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling + if pool_type == :identity + @assert use_conv + "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" + flatten_in_pool = false # disable flattening if pooling is pass-through (no pooling) + end + global_pool = SelectAdaptivePool(; pool_type, flatten = flatten_in_pool) + fc = use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : + Dense(inplanes => nclasses; bias = true) + return global_pool, fc +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 6363946d0..e56967aef 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -144,27 +144,6 @@ function skip_identity(inplanes, outplanes) end skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes) -""" - squeeze_excite(channels, reduction = 4) - -Squeeze and excitation layer used by MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `channels`: the number of input/output feature maps - - `reduction = 4`: the reduction factor for the number of hidden feature maps - (must be ≥ 1) -""" -function squeeze_excite(channels, reduction = 4) - @assert (reduction>=1) "`reduction` must be >= 1" - return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - conv_bn((1, 1), channels, channels ÷ reduction, relu; - bias = false)..., - conv_bn((1, 1), channels ÷ reduction, channels, hardσ)...), - .*) -end - """ invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; stride, reduction = nothing) @@ -190,7 +169,9 @@ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, pad = @. 
(kernel_size - 1) ÷ 2 conv1 = (inplanes == hidden_planes) ? identity : Chain(conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false)) - selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes, reduction) + selayer = isnothing(reduction) ? identity : + squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, + norm_layer = BatchNorm) invres = Chain(conv1, conv_bn(kernel_size, hidden_planes, hidden_planes, activation; bias = false, stride, pad = pad, groups = hidden_planes)..., diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index ad079db9d..66f25d1c0 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -11,7 +11,7 @@ patches. # Arguments: - `imsize`: the size of the input image - - `inchannels`: the number of channels in the input. The default value is 3. + - `inchannels`: the number of channels in the input. - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding - `norm_layer`: the normalization layer - by default the identity function but otherwise takes a diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp-linear.jl index 550c2ad22..8cca1e266 100644 --- a/src/layers/mlp-linear.jl +++ b/src/layers/mlp-linear.jl @@ -15,7 +15,7 @@ end """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - drop_rate =0., activation = gelu) + dropout_rate =0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -24,18 +24,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models. - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `drop_rate`: Dropout rate. + - `dropout_rate`: Dropout rate. - `activation`: Activation function to use. """ function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - drop_rate = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(drop_rate), - Dense(hidden_planes, outplanes), Dropout(drop_rate)) + dropout_rate = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout_rate), + Dense(hidden_planes, outplanes), Dropout(dropout_rate)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) + outplanes::Integer = inplanes; dropout_rate = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -46,16 +46,17 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `drop_rate`: Dropout rate. + - `dropout_rate`: Dropout rate. - `activation`: Activation function to use. 
""" function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) + outplanes::Integer = inplanes; dropout_rate = 0.0, + activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(drop_rate), + Dropout(dropout_rate), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(drop_rate)) + Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/layers/pool.jl b/src/layers/pool.jl new file mode 100644 index 000000000..aa5755240 --- /dev/null +++ b/src/layers/pool.jl @@ -0,0 +1,26 @@ +function AdaptiveMeanMaxPool(output_size = (1, 1)) + return 0.5 * Parallel(.+, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) +end + +function AdaptiveCatMeanMaxPool(output_size = (1, 1)) + return Parallel(cat_channels, AdaptiveAvgMaxPool(output_size), + AdaptiveMaxPool(output_size)) +end + +function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) + if pool_type == :mean + pool = AdaptiveAvgPool(output_size) + elseif pool_type == :max + pool = AdaptiveMaxPool(output_size) + elseif pool_type == :meanmax + pool = AdaptiveAvgMaxPool(output_size) + elseif pool_type == :catmeanmax + pool = AdaptiveCatAvgMaxPool(output_size) + elseif pool_type = :identity + pool = identity + else + throw(AssertionError("Invalid pool type: $pool_type")) + end + flatten_fn = flatten ? MLUtils.flatten : identity + return Chain(pool, flatten_fn) +end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl new file mode 100644 index 000000000..acd7e9809 --- /dev/null +++ b/src/layers/selayers.jl @@ -0,0 +1,41 @@ +""" + squeeze_excite(inplanes, reduction = 16; rd_divisor = 8, + activation = relu, gate_activation = sigmoid, norm_layer = identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + +Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. + +# Arguments + + - `inplanes`: The number of input feature maps + - `reduction`: The reduction factor for the number of hidden feature maps + - `rd_divisor`: The divisor for the number of hidden feature maps. + - `activation`: The activation function for the first convolution layer + - `gate_activation`: The activation function for the gate layer + - `norm_layer`: The normalization layer to be used after the convolution layers + - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer + Must be ≥ 1 or `nothing` for no squeeze and excite layer. +""" +function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, + activation = relu, gate_activation = sigmoid, norm_layer = identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes => rd_planes), + norm_layer, + activation, + Conv((1, 1), rd_planes => inplanes), + norm_layer, + gate_activation), .*) +end + +""" + effective_squeeze_excite(inplanes, gate_layer = sigmoid) + +Effective squeeze-and-excitation layer. +(reference: [CenterMask : Real-Time Anchor-Free Instance Segmentation](https://arxiv.org/abs/1911.06667)) +""" +function effective_squeeze_excite(inplanes; gate_activation = sigmoid, kwargs...) 
+ return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes, inplanes), + gate_activation), .*) +end diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index ed4c47af3..5083b228e 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - drop_rate =0., drop_path_rate = 0., activation = gelu) + dropout_rate =0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -12,22 +12,22 @@ Creates a feedforward block for the MLPMixer architecture. - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP and/or the channel mixing MLP as a ratio to the number of planes in the block. - `mlp_layer`: the MLP layer to use in the block - - `drop_rate`: the dropout rate to use in the MLP blocks + - `dropout_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] return Chain(SkipConnection(Chain(LayerNorm(planes), swapdims((2, 1, 3)), mlp_layer(npatches, tokenplanes; activation, - drop_rate), + dropout_rate), swapdims((2, 1, 3)), DropPath(drop_path_rate)), +), SkipConnection(Chain(LayerNorm(planes), mlp_layer(planes, channelplanes; activation, - drop_rate), + dropout_rate), DropPath(drop_path_rate)), +)) end @@ -115,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] """ - resmixerblock(planes, npatches; drop_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; dropout_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, activation = gelu, λ = 1e-4) Creates a block for the ResMixer architecture. @@ -128,13 +128,14 @@ Creates a block for the ResMixer architecture. 
- `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `mlp_layer`: the MLP block to use - - `drop_rate`: the dropout rate to use in the MLP blocks + - `dropout_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks - `λ`: initialisation constant for the LayerScale """ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, + λ = 1e-4) return Chain(SkipConnection(Chain(Flux.Scale(planes), swapdims((2, 1, 3)), Dense(npatches, npatches), @@ -142,7 +143,8 @@ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, LayerScale(planes, λ), DropPath(drop_path_rate)), +), SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); drop_rate, + mlp_layer(planes, Int(mlp_ratio * planes); + dropout_rate, activation), LayerScale(planes, λ), DropPath(drop_path_rate)), +)) @@ -232,7 +234,7 @@ end """ spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, drop_rate = 0.0, drop_path_rate = 0.0, + norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) Creates a feedforward block based on the gMLP model architecture described in the paper. @@ -245,19 +247,19 @@ Creates a feedforward block based on the gMLP model architecture described in th - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `norm_layer`: the normalisation layer to use - - `drop_rate`: the dropout rate to use in the MLP blocks + - `dropout_rate`: the dropout rate to use in the MLP blocks - `drop_path_rate`: Stochastic depth rate - `activation`: the activation function to use in the MLP blocks """ function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, drop_rate = 0.0, + mlp_layer = gated_mlp_block, dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) channelplanes = Int(mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), mlp_layer(sgu, planes, channelplanes; activation, - drop_rate), + dropout_rate), DropPath(drop_path_rate)), +) end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 686ddc4d5..a06ce6886 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate =0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate =0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). @@ -10,23 +10,23 @@ Transformer as used in the base ViT architecture. 
- `depth`: number of attention blocks - `nheads`: number of attention heads - `mlp_ratio`: ratio of MLP layers to the number of input channels - - `drop_rate`: dropout rate + - `dropout_rate`: dropout rate """ -function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate = 0.0) +function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.0) layers = [Chain(SkipConnection(prenorm(planes, MHAttention(planes, nheads; - attn_drop_rate = drop_rate, - proj_drop_rate = drop_rate)), +), + attn_drop_rate = dropout_rate, + proj_drop_rate = dropout_rate)), +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); - drop_rate)), +)) + dropout_rate)), +)) for _ in 1:depth] return Chain(layers) end """ vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, emb_drop_rate = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. @@ -41,13 +41,13 @@ Creates a Vision Transformer (ViT) model. - `depth`: number of blocks in the transformer - `nheads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer - - `drop_rate`: dropout rate + - `dropout_rate`: dropout rate - `emb_dropout`: dropout rate for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1, + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, emb_drop_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)" @@ -57,7 +57,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = ViPosEmbedding(embedplanes, npatches + 1), Dropout(emb_drop_rate), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, - drop_rate), + dropout_rate), (pool == :class) ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end diff --git a/test/convnets.jl b/test/convnets.jl index 53006360b..a716e9f9d 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -47,9 +47,9 @@ GC.gc() ] @testset for layers in layer_list drop_list = [ - (drop_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), - (drop_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), - (drop_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), + (dropout_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), + (dropout_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), + (dropout_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), ] @testset for drop_rates in drop_list m = Metalhead.resnet(block_fn, layers; drop_rates) From 71473094f64d2a5ea8ba9258db98ea54b596bb25 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 1 Jul 2022 12:47:22 +0530 Subject: [PATCH 36/64] Tweaks III + Some more docs 1. Reorganise layer imports for easy access 2. 
Get pooling to work --- src/convnets/resne(x)t.jl | 17 ++++---- src/layers/Layers.jl | 30 +++++++++---- src/layers/classifier.jl | 2 +- src/layers/drop.jl | 88 ++++++++++++++++++++++++++++++++++----- src/layers/pool.jl | 8 ++-- src/layers/selayers.jl | 18 +++++--- 6 files changed, 125 insertions(+), 38 deletions(-) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 74140d625..424c0dd55 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -158,10 +158,14 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 - dbr = haskey(drop_rates, :drop_block_rate) ? drop_rates.drop_block_rate : 0 + # Stochastic depth linear decay rule (DropPath) + dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), + sum(block_repeats)) for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, - _drop_blocks(dbr))) + _drop_blocks(get(drop_rates, + :drop_block_rate, + 0)))) # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -178,13 +182,11 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride for block_idx in 1:num_blocks downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 - # stochastic depth linear decay rule - dpr = haskey(drop_rates, :drop_path_rate) ? drop_rates.drop_path_rate : 0 - block_dpr = dpr * net_block_idx / (sum(block_repeats) - 1) push!(blocks, block_fn(inplanes, planes; stride, downsample, first_dilation = prev_dilation, - drop_path = DropPath(block_dpr), drop_block, block_args...)) + drop_path = DropPath(dp_rates[block_idx]), drop_block, + block_args...)) prev_dilation = dilation inplanes = planes * expansion net_block_idx += 1 @@ -216,8 +218,7 @@ function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride expansion = expansion_factor(block_fn) num_features = 512 * expansion global_pool, fc = create_classifier(num_features, nclasses; classifier_args...) - dr = haskey(drop_rates, :dropout_rate) ? 
drop_rates.dropout_rate : 0 - classifier = Chain(global_pool, Dropout(dr), fc) + classifier = Chain(global_pool, Dropout(get(drop_rates, :dropout_rate, 0)), fc) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index f58f40172..41a98843e 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,6 +1,7 @@ module Layers using Flux +using Flux: rng_from_array using CUDA using NNlib, NNlibCUDA using Functors @@ -12,21 +13,32 @@ using Random include("../utilities.jl") include("attention.jl") +export MHAttention + include("embeddings.jl") +export PatchEmbedding, ViPosEmbedding, ClassTokens + include("mlp-linear.jl") +export mlp_block, gated_mlp_block, LayerScale + include("normalise.jl") +export prenorm, ChannelLayerNorm + include("conv.jl") +export conv_bn, depthwise_sep_conv_bn, invertedresidual +skip_identity, skip_projection + include("drop.jl") +export DropPath, DropBlock + include("selayers.jl") +export squeeze_excite, effective_squeeze_excite + include("classifier.jl") +export create_classifier + +include("pool.jl") +export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool +SelectAdaptivePool -export MHAttention, - PatchEmbedding, ViPosEmbedding, ClassTokens, - mlp_block, gated_mlp_block, - LayerScale, DropPath, DropBlock, - ChannelLayerNorm, prenorm, - skip_identity, skip_projection, - conv_bn, depthwise_sep_conv_bn, - squeeze_excite, effective_squeeze_excite, - invertedresidual, create_classifier end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl index 04be6ef86..e2a1fe75c 100644 --- a/src/layers/classifier.jl +++ b/src/layers/classifier.jl @@ -1,4 +1,4 @@ -function create_classifier(inplanes, nclasses; pool_type = :avg, use_conv = false) +function create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling if pool_type == :identity @assert use_conv diff --git a/src/layers/drop.jl b/src/layers/drop.jl index dbb7ddc34..c89bd55fe 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,6 +1,27 @@ +""" + dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale, active::Bool = true) + +The dropblock function. If `active` is `true`, for each input, it zeroes out continguous +regions of size `block_size` in the input. Otherwise, it simply returns the input `x`. + +# Arguments + + - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only + supported on the CPU. + - `x`: input array + - `drop_block_prob`: probability of dropping a block + - `block_size`: size of the block to drop + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, + refer to [the paper](https://arxiv.org/abs/1810.12890). + +If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. +""" function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale, active::Bool = true) where {T} - active || return x + if !active + return x + end H, W, _, _ = size(x) total_size = H * W clipped_block_size = min(block_size, min(H, W)) @@ -13,12 +34,14 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, bl normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) return x .* block_mask .* normalize_scale end -dropoutblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) + +dropblock(x, p, args...) 
= dropblock(rng_from_array(x), x, p, args...) +dropblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) function dropblock(rng, x::CuArray, p, args...) - throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays.")) + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports CUDA.RNG for CuArrays.")) end -struct DropBlock{F, R <: AbstractRNG} +mutable struct DropBlock{F, R <: AbstractRNG} drop_block_prob::F block_size::Integer gamma_scale::F @@ -41,16 +64,36 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x) function (m::DropBlock)(x) _dropblock_checks(x) - Flux._isactive(m) || return x - return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) + if Flux._isactive(m) + return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) + else + return x + end end function Flux.testmode!(m::DropBlock, mode = true) return (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) end +""" + DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, + rng = rng_from_array()) + +The `DropBlock` layer. While training, it zeroes out continguous regions of +size `block_size` in the input. During inference, it simply returns the input `x`. +((reference)[https://arxiv.org/abs/1810.12890]) + +# Arguments + + - `drop_block_prob`: probability of dropping a block + - `block_size`: size of the block to drop + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, + refer to [the paper](https://arxiv.org/abs/1810.12890). + - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only + supported on the CPU. +""" function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, - rng = Flux.rng_from_array()) + rng = rng_from_array()) if drop_block_prob == 0.0 return identity end @@ -61,15 +104,40 @@ function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) end +function Base.show(io::IO, d::DropBlock) + print(io, "DropBlock(", d.drop_block_prob) + print(io, ", block_size = $(repr(d.block_size))") + print(io, ", gamma_scale = $(repr(d.gamma_scale))") + return print(io, ")") +end + """ - DropPath(p) + DropPath(p; [rng = rng_from_array(x)]) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0 and +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and `identity` otherwise. ([reference](https://arxiv.org/abs/1603.09382)) +This layer can be used to drop certain blocks in a residual structure and allow them to +propagate completely through the skip connection. It can be used in two ways: either with +all blocks having the same survival probability or with a linear scaling rule across the +blocks. This is performed only at training time. At test time, the `DropPath` layer is +equivalent to `identity`. + +!!! warning + + In the case of the linear scaling rule, the calculations of survival probabilities for each + block may lead to a survival probability > 1 for a given block. This will lead to + `DropPath` returning `identity`, which may not be desirable. This usually happens with + a low number of blocks and a high base survival probability, so it is recommended to + use a fixed base survival probability across blocks. If this is not possible, then + a lower base survival probability is recommended. + # Arguments - `p`: rate of Stochastic Depth. 
+ - `rng`: can be used to pass in a custom RNG instead of the default. See `Flux.Dropout` + for more information on the behaviour of this argument. Custom RNGs are only supported + on the CPU. """ -DropPath(p) = p > 0 ? Dropout(p; dims = 4) : identity +DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity diff --git a/src/layers/pool.jl b/src/layers/pool.jl index aa5755240..4ffe298e3 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -9,14 +9,14 @@ end function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) if pool_type == :mean - pool = AdaptiveAvgPool(output_size) + pool = AdaptiveMeanPool(output_size) elseif pool_type == :max pool = AdaptiveMaxPool(output_size) elseif pool_type == :meanmax - pool = AdaptiveAvgMaxPool(output_size) + pool = AdaptiveMeanMaxPool(output_size) elseif pool_type == :catmeanmax - pool = AdaptiveCatAvgMaxPool(output_size) - elseif pool_type = :identity + pool = AdaptiveCatMeanMaxPool(output_size) + elseif pool_type == :identity pool = identity else throw(AssertionError("Invalid pool type: $pool_type")) diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index acd7e9809..7f1a76d59 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -1,5 +1,5 @@ """ - squeeze_excite(inplanes, reduction = 16; rd_divisor = 8, + squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, activation = relu, gate_activation = sigmoid, norm_layer = identity, rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) @@ -17,22 +17,28 @@ Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, - activation = relu, gate_activation = sigmoid, norm_layer = identity, - rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + activation = relu, gate_activation = sigmoid, + norm_layer = planes -> identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), Conv((1, 1), inplanes => rd_planes), - norm_layer, + norm_layer(rd_planes), activation, Conv((1, 1), rd_planes => inplanes), - norm_layer, + norm_layer(inplanes), gate_activation), .*) end """ - effective_squeeze_excite(inplanes, gate_layer = sigmoid) + effective_squeeze_excite(inplanes, gate_activation = sigmoid) Effective squeeze-and-excitation layer. (reference: [CenterMask : Real-Time Anchor-Free Instance Segmentation](https://arxiv.org/abs/1911.06667)) + +# Arguments + + - `inplanes`: The number of input feature maps + - `gate_activation`: The activation function for the gate layer """ function effective_squeeze_excite(inplanes; gate_activation = sigmoid, kwargs...) return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), From 7ed20d45e65c62034b76ec2b78b57c4db2dca30f Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 3 Jul 2022 10:34:41 +0530 Subject: [PATCH 37/64] Fix `DropBlock` on the GPU --- src/convnets/resne(x)t.jl | 2 +- src/layers/drop.jl | 43 ++++++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resne(x)t.jl index 424c0dd55..03ed646c0 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resne(x)t.jl @@ -288,7 +288,7 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. 
Must be one of [50, 101, 152]" - model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, + model = resnet(resnet_config[depth]...; nclasses, block_args = (; cardinality, base_width)) if pretrain loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) diff --git a/src/layers/drop.jl b/src/layers/drop.jl index c89bd55fe..dc6cb3c54 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,3 +1,11 @@ +# Generates the mask to be used for `DropBlock` +@inline function _dropblock_mask(rng, x, gamma, clipped_block_size) + block_mask = Flux.f32(rand_like(rng, x) .< gamma) + return 1 .- maxpool(block_mask, (clipped_block_size, clipped_block_size); + stride = 1, pad = clipped_block_size ÷ 2) +end +ChainRulesCore.@non_differentiable _dropblock_mask(rng, x, gamma, clipped_block_size) + """ dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale, active::Bool = true) @@ -18,28 +26,25 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. """ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, - gamma_scale, active::Bool = true) where {T} - if !active - return x - end + gamma_scale) where {T} H, W, _, _ = size(x) total_size = H * W clipped_block_size = min(block_size, min(H, W)) gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / ((W - block_size + 1) * (H - block_size + 1)) - block_mask = rand_like(rng, x) .< gamma - block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size); - stride = 1, pad = clipped_block_size ÷ 2) - block_mask = 1 .- block_mask + block_mask = dropblock_mask(rng, x, gamma, clipped_block_size) normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) return x .* block_mask .* normalize_scale end -dropblock(x, p, args...) = dropblock(rng_from_array(x), x, p, args...) -dropblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...) -function dropblock(rng, x::CuArray, p, args...) +## bs is `clipped_block_size` +# Dispatch for GPU +dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) +function dropblock_mask(rng, x::CuArray, gamma, bs) throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). 
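For intuition, a worked example (not part of the patch) of the `gamma` seeding probability computed in `dropblock` above, with illustrative feature-map and block sizes:

```julia
H, W = 28, 28                                    # spatial size of the feature map
block_size, drop_block_prob, gamma_scale = 7, 0.1, 1.0
total_size = H * W
clipped_block_size = min(block_size, min(H, W))
gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 /
        ((W - block_size + 1) * (H - block_size + 1))
# gamma ≈ 0.0033: each position seeds a dropped block with this small probability, so that
# after the max-pooling step roughly `drop_block_prob` of the activations end up zeroed.
```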
dropblock only supports CUDA.RNG for CuArrays.")) end +# Dispatch for CPU +dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) mutable struct DropBlock{F, R <: AbstractRNG} drop_block_prob::F @@ -52,7 +57,11 @@ end @functor DropBlock trainable(a::DropBlock) = (;) -function _dropblock_checks(x::T) where {T} +function _dropblock_checks(x::T, drop_block_prob, gamma_scale) where {T} + @assert 0 ≤ drop_block_prob ≤ 1 + "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0 ≤ gamma_scale ≤ 1 + "gamma_scale must be between 0 and 1, got $gamma_scale" if !(T <: AbstractArray) throw(ArgumentError("x must be an `AbstractArray`")) end @@ -60,10 +69,10 @@ function _dropblock_checks(x::T) where {T} throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) end end -ChainRulesCore.@non_differentiable _dropblock_checks(x) +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_scale) function (m::DropBlock)(x) - _dropblock_checks(x) + _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) if Flux._isactive(m) return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) else @@ -87,7 +96,7 @@ size `block_size` in the input. During inference, it simply returns the input `x - `drop_block_prob`: probability of dropping a block - `block_size`: size of the block to drop - - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. @@ -97,10 +106,6 @@ function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, if drop_block_prob == 0.0 return identity end - @assert 0 ≤ drop_block_prob ≤ 1 - "drop_block_prob must be between 0 and 1, got $drop_block_prob" - @assert 0 ≤ gamma_scale ≤ 1 - "gamma_scale must be between 0 and 1, got $gamma_scale" return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) end From f0051b70b0b18840a478b2304d53e83afbf907ac Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 3 Jul 2022 11:57:23 +0530 Subject: [PATCH 38/64] Add `SEResNet` and `SEResNeXt` So much GC, might as well have a function for it --- src/Metalhead.jl | 5 +- src/convnets/{resne(x)t.jl => resnets.jl} | 96 ++++++++++++++++--- test/convnets.jl | 107 ++++++++++------------ test/other.jl | 9 +- test/runtests.jl | 5 + test/vit-based.jl | 3 +- 6 files changed, 145 insertions(+), 80 deletions(-) rename src/convnets/{resne(x)t.jl => resnets.jl} (80%) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 172f01d16..a4dd73785 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -22,7 +22,7 @@ include("convnets/alexnet.jl") include("convnets/vgg.jl") include("convnets/inception.jl") include("convnets/googlenet.jl") -include("convnets/resne(x)t.jl") +include("convnets/resnets.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") @@ -43,6 +43,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, + SEResNet, SEResNeXt, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt @@ -50,7 +51,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, 
VGG19, # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, - :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, + :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :SEResNet, :SEResNeXt, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/resne(x)t.jl b/src/convnets/resnets.jl similarity index 80% rename from src/convnets/resne(x)t.jl rename to src/convnets/resnets.jl index 03ed646c0..73f070617 100644 --- a/src/convnets/resne(x)t.jl +++ b/src/convnets/resnets.jl @@ -26,7 +26,8 @@ end function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) @assert cardinality==1 "`basicblock` only supports cardinality of 1" @assert base_width==64 "`basicblock` does not support changing base width" @@ -40,8 +41,10 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = dilation, dilation = dilation, bias = false), norm_layer(outplanes)) + attn_layer = attn_fn(outplanes; attn_args...) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, + drop_path)), activation) end expansion_factor(::typeof(basicblock)) = 1 @@ -49,7 +52,8 @@ expansion_factor(::typeof(basicblock)) = 1 function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity) + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first @@ -61,9 +65,10 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina dilation = first_dilation, groups = cardinality, bias = false), norm_layer(width)) conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) + attn_layer = attn_fn(outplanes; attn_args...) return Chain(Parallel(+, downsample, Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, - drop_path)), + attn_layer, drop_path)), activation) end expansion_factor(::typeof(bottleneck)) = 4 @@ -233,10 +238,13 @@ struct ResNet end @functor ResNet +(m::ResNet)(x) = m.layers(x) + """ ResNet(depth::Integer; pretrain = false, nclasses = 1000) Creates a ResNet model with the specified depth. +((reference)[https://arxiv.org/abs/1512.03385]) # Arguments @@ -253,11 +261,11 @@ Advanced users who want more configuration options will be better served by usin function ResNet(depth::Integer; pretrain = false, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. 
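The `attn_fn`/`attn_args` hooks added to the blocks above are what let squeeze-and-excitation (or other channel-attention layers) be injected into a standard ResNet. A sketch of driving this through the internal `resnet` builder, which is what the `SEResNet` constructor below does under the hood (all names here are non-exported; the depths and reduction are illustrative):

```julia
using Metalhead

# ResNet-50-style backbone whose bottleneck blocks each end in a squeeze-excite layer
layers = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3]; nclasses = 1000,
                          block_args = (; attn_fn = Metalhead.Layers.squeeze_excite,
                                        attn_args = (; reduction = 16)))
```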
Must be one of [18, 34, 50, 101, 152]" - model = resnet(resnet_config[depth]...; nclasses) + layers = resnet(resnet_config[depth]...; nclasses) if pretrain - loadpretrain!(model, string("resnet", depth)) + loadpretrain!(layers, string("resnet", depth)) end - return model + return ResNet(layers) end struct ResNeXt @@ -265,10 +273,13 @@ struct ResNeXt end @functor ResNeXt +(m::ResNeXt)(x) = m.layers(x) + """ ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) Creates a ResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/abs/1611.05431]) # Arguments @@ -288,10 +299,73 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - model = resnet(resnet_config[depth]...; nclasses, - block_args = (; cardinality, base_width)) + layers = resnet(resnet_config[depth]...; nclasses, + block_args = (; cardinality, base_width)) + if pretrain + loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width)) + end + return ResNeXt(layers) +end + +struct SEResNet + layers::Any +end +@functor SEResNet + +(m::SEResNet)(x) = m.layers(x) + +""" + SEResNet(depth::Integer; pretrain = false, nclasses = 1000) + +Creates a SEResNet model with the specified depth. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `nclasses`: the number of output classes +""" +function SEResNet(depth::Integer; pretrain = false, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + layers = resnet(resnet_config[depth]...; nclasses, + block_args = (; attn_fn = squeeze_excite)) + if pretrain + loadpretrain!(layers, string("seresnet", depth)) + end + return SEResNet(layers) +end + +struct SEResNeXt + layers::Any +end +@functor SEResNeXt + +(m::SEResNeXt)(x) = m.layers(x) + +""" + SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + +Creates a SEResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `nclasses`: the number of output classes +""" +function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + layers = resnet(resnet_config[depth]...; nclasses, + block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) if pretrain - loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width)) + loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) end - return model + return SEResNeXt(layers) end diff --git a/test/convnets.jl b/test/convnets.jl index a716e9f9d..f6993cdb7 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -3,11 +3,9 @@ @test size(model(x_256)) == (1000, 1) @test_throws ArgumentError AlexNet(pretrain = true) @test gradtest(model, x_256) + _gc() end -GC.safepoint() -GC.gc() - @testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] m = VGG(sz, batchnorm = bn) @@ -18,14 +16,10 @@ GC.gc() @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "ResNet" begin # Tests for pretrained ResNets ## TODO: find a way to port pretrained models to the new ResNet API @@ -55,17 +49,13 @@ GC.gc() m = Metalhead.resnet(block_fn, layers; drop_rates) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end end end -GC.safepoint() -GC.gc() - @testset "ResNeXt" begin @testset for depth in [50, 101, 152] @testset for cardinality in [32, 64] @@ -78,15 +68,43 @@ GC.gc() @test_throws ArgumentError ResNeXt(depth, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end end -GC.safepoint() -GC.gc() +@testset "SEResNet" begin + @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) + @test size(m(x_224)) == (1000, 1) + if string("seresnet", depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNet(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end +end + +@testset "SEResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if string("seresnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + end +end @testset "EfficientNet" begin @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4] #, :b5, :b6, :b7, :b8] @@ -101,14 +119,10 @@ GC.gc() @test_throws ArgumentError EfficientNet(name, pretrain = true) end @test gradtest(m, x) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(x_224)) == (1000, 1) @@ -118,11 +132,9 @@ GC.gc() @test_throws ArgumentError GoogLeNet(pretrain = true) end @test gradtest(m, x_224) + _gc() end -GC.safepoint() -GC.gc() - @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @testset "Inceptionv3" begin @@ -135,8 +147,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "Inceptionv4" begin m = Inceptionv4() @test size(m(x_299)) == (1000, 2) @@ -147,8 +158,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "InceptionResNetv2" begin m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) @@ -159,8 +169,7 @@ GC.gc() end @test 
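A quick usage sketch for the two newly exported models, mirroring the tests added in this patch (input size and depths are illustrative):

```julia
using Metalhead

m1 = SEResNet(50)
m2 = SEResNeXt(50; cardinality = 32, base_width = 4)
x = rand(Float32, 224, 224, 3, 1)
size(m1(x))  # (1000, 1)
size(m2(x))  # (1000, 1)
```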
gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "Xception" begin m = Xception() @test size(m(x_299)) == (1000, 2) @@ -171,11 +180,9 @@ GC.gc() end @test gradtest(m, x_299) end + _gc() end -GC.safepoint() -GC.gc() - @testset "SqueezeNet" begin m = SqueezeNet() @test size(m(x_224)) == (1000, 1) @@ -185,15 +192,12 @@ GC.gc() @test_throws ArgumentError SqueezeNet(pretrain = true) end @test gradtest(m, x_224) + _gc() end -GC.safepoint() -GC.gc() - @testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) if (DenseNet, sz) in PRETRAINED_MODELS @test acctest(DenseNet(sz, pretrain = true)) @@ -201,18 +205,13 @@ GC.gc() @test_throws ArgumentError DenseNet(sz, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "MobileNet" verbose = true begin @testset "MobileNetv1" begin m = MobileNetv1() - @test size(m(x_224)) == (1000, 1) if MobileNetv1 in PRETRAINED_MODELS @test acctest(MobileNetv1(pretrain = true)) @@ -221,8 +220,7 @@ GC.gc() end @test gradtest(m, x_224) end - GC.safepoint() - GC.gc() + _gc() @testset "MobileNetv2" begin m = MobileNetv2() @test size(m(x_224)) == (1000, 1) @@ -233,12 +231,10 @@ GC.gc() end @test gradtest(m, x_224) end - GC.safepoint() - GC.gc() + _gc() @testset "MobileNetv3" verbose = true begin @testset for mode in [:small, :large] m = MobileNetv3(mode) - @test size(m(x_224)) == (1000, 1) if (MobileNetv3, mode) in PRETRAINED_MODELS @test acctest(MobileNetv3(mode; pretrain = true)) @@ -246,12 +242,11 @@ GC.gc() @test_throws ArgumentError MobileNetv3(mode; pretrain = true) end @test gradtest(m, x_224) + _gc() end end end -GC.safepoint() -GC.gc() @testset "ConvNeXt" verbose = true begin @testset for mode in [:small, :base, :large, :tiny, :xlarge] @@ -259,22 +254,16 @@ GC.gc() m = ConvNeXt(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end -GC.safepoint() -GC.gc() - @testset "ConvMixer" verbose = true begin @testset for mode in [:small, :base, :large] m = ConvMixer(mode) - @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end diff --git a/test/other.jl b/test/other.jl index 3c1752f3a..df97d4f5f 100644 --- a/test/other.jl +++ b/test/other.jl @@ -4,8 +4,7 @@ m = MLPMixer(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end @@ -16,8 +15,7 @@ end m = ResMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end @@ -28,8 +26,7 @@ end m = gMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end diff --git a/test/runtests.jl b/test/runtests.jl index f1a9787b9..55e416ac2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,11 @@ const PRETRAINED_MODELS = [ (ResNet, 152), ] +function _gc() + GC.safepoint() + GC.gc() +end + function gradtest(model, input) y, pb = Zygote.pullback(() -> model(input), Flux.params(model)) gs = pb(ones(Float32, size(y))) diff --git a/test/vit-based.jl b/test/vit-based.jl index 9dc348819..e889b07be 100644 --- a/test/vit-based.jl +++ b/test/vit-based.jl @@ -3,7 +3,6 @@ m = ViT(mode) @test size(m(x_256)) == (1000, 1) @test gradtest(m, x_256) - GC.safepoint() - GC.gc() + _gc() end end From e5d2295619e8b658d99c61f5c8c15dbeb30b9e80 Mon Sep 17 00:00:00 2001 From: 
Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Mon, 4 Jul 2022 17:46:51 +0530 Subject: [PATCH 39/64] More docs, more tweaks --- src/Metalhead.jl | 4 +- src/convnets/convnext.jl | 22 ++-- src/convnets/resnets.jl | 270 ++++++++++++++++++++++++++++++++------- src/layers/Layers.jl | 6 +- src/layers/attention.jl | 4 +- src/layers/classifier.jl | 13 ++ src/layers/conv.jl | 43 ++----- src/layers/embeddings.jl | 2 +- src/layers/normalise.jl | 4 +- src/layers/pool.jl | 35 +++-- src/layers/selayers.jl | 1 - src/other/mlpmixer.jl | 6 +- src/vit-based/vit.jl | 2 +- 13 files changed, 297 insertions(+), 115 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index a4dd73785..e60eff405 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -49,9 +49,9 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt, +for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, - :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :SEResNet, :SEResNeXt, + :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index f3da6dbf3..6ced7eeb9 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -4,7 +4,7 @@ Creates a single block of ConvNeXt. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `planes`: number of input channels. - `drop_path_rate`: Stochastic depth rate. @@ -27,7 +27,7 @@ end Creates the layers for a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `inchannels`: number of input channels. - `depths`: list with configuration for depth of each block @@ -39,32 +39,29 @@ Creates the layers for a ConvNeXt model. """ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert length(depths)==length(planes) "`planes` should have exactly one value for each block" - + @assert length(depths) == length(planes) + "`planes` should have exactly one value for each block" downsample_layers = [] stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4), - ChannelLayerNorm(planes[1]; ϵ = 1.0f-6)) + ChannelLayerNorm(planes[1])) push!(downsample_layers, stem) for m in 1:(length(depths) - 1) - downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6), + downsample_layer = Chain(ChannelLayerNorm(planes[m]), Conv((2, 2), planes[m] => planes[m + 1]; stride = 2)) push!(downsample_layers, downsample_layer) end - stages = [] dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths)) cur = 0 - for i in 1:length(depths) + for i in eachindex(depths) push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]]) cur += depths[i] end - backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages)))) head = Chain(GlobalMeanPool(), MLUtils.flatten, LayerNorm(planes[end]), Dense(planes[end], nclasses)) - return Chain(Chain(backbone), head) end @@ -90,7 +87,7 @@ end Creates a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `inchannels`: The number of channels in the input. - `drop_path_rate`: Stochastic depth rate. 
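A short usage sketch for the constructor documented above; the mode and drop-path rate are illustrative, and the tests in this PR exercise the same call:

```julia
using Metalhead

m = ConvNeXt(:tiny; drop_path_rate = 0.1)
size(m(rand(Float32, 224, 224, 3, 1)))  # (1000, 1)
```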
@@ -101,7 +98,8 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))" + @assert mode in keys(convnext_configs) + "`size` must be one of $(collect(keys(convnext_configs)))" depths = convnext_configs[mode][:depths] planes = convnext_configs[mode][:planes] layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses) diff --git a/src/convnets/resnets.jl b/src/convnets/resnets.jl index 73f070617..f46ac29ad 100644 --- a/src/convnets/resnets.jl +++ b/src/convnets/resnets.jl @@ -1,31 +1,50 @@ -function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(outplanes)) -end +# resnet.jl +## It is recommended to check out the user's guide for more information +## regarding the use of these functions. + +### ResNet blocks +## These functions return a block to be used inside of a ResNet model. +## The individual arguments are explained in the documentation of the functions. +## Note that for these blocks to be used by the `_make_blocks` function, they must define +## a dispatch `expansion(::typeof(fn))` that returns the expansion factor of the block +## (i.e. the multiplicative factor by which the number of channels in the input is increased). +## The `_make_blocks` function will then call the `expansion` function to determine the +## expansion factor of each block and use this to construct the stages of the model. -function downsample_avg(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 - pool = MeanPool((2, 2); stride = avg_stride, pad) - end - return Chain(pool, - Conv((1, 1), inplanes => outplanes; bias = false), - norm_layer(outplanes)) -end +""" + basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) + +Creates a basic ResNet block. +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `downsample`: the downsampling function to use + - `cardinality`: redundant, kept for compatibility with `bottleneck`. + - `base_width`: redundant, kept for compatibility with `bottleneck`. + - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + convolution. + - `dilation`: the dilation of the second convolution. + - `first_dilation`: the dilation of the first convolution. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer. 
This is usually initialised in the `_make_blocks` + function and passed in. + - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. + - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the + attention function. +""" function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) @@ -33,7 +52,6 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina @assert base_width==64 "`basicblock` does not support changing base width" first_planes = planes ÷ reduce_first outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, dilation = first_dilation, bias = false), norm_layer(first_planes)) @@ -49,16 +67,46 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina end expansion_factor(::typeof(basicblock)) = 1 +""" + bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = dilation, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) + +Creates a bottleneck ResNet block. + +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `downsample`: the downsampling function to use + - `cardinality`: the number of groups in the convolution. + - `base_width`: the number of output feature maps for each convolutional group. + - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + convolution. + - `dilation`: redundant, kept for compatibility with `basicblock`. + - `first_dilation`: the dilation of the 3x3 convolution. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. + - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the + attention function. 
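As a worked example (not in the patch) of the grouped-convolution width that `bottleneck` computes from these arguments, a ResNeXt-style 32×4d block with `planes = 256` works out as:

```julia
planes, base_width, cardinality, reduce_first = 256, 4, 32, 1
width = floor(Int, planes * (base_width / 64)) * cardinality  # 16 * 32 = 512
first_planes = width ÷ reduce_first                           # 512
outplanes = planes * 4                                        # expansion_factor(bottleneck) = 4, so 1024
```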
+""" function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, + first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first outplanes = planes * expansion - first_dilation = !isnothing(first_dilation) ? first_dilation : dilation conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), norm_layer(first_planes, activation)) conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, @@ -77,17 +125,17 @@ expansion_factor(::typeof(bottleneck)) = 4 resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) -Builds a stem to be used in a ResNet model. See the `stem` argument of `resnet` for details +Builds a stem to be used in a ResNet model. See the `stem` argument of [`resnet`](#) for details on how to use this function. -# Arguments: +# Arguments - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. + `:default`: Builds a stem based on the default ResNet stem, which consists of a single 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling layer with stride 2. - + `:deep`: This borrows ideas from other papers (InceptionResNet-v2 for one) in using a + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using a deeper stem with 3 successive 3x3 convolutions having normalisation layers after each one. This is followed by a 3x3 max pooling layer with stride 2. + `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second @@ -137,13 +185,62 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = return Chain(conv1, bn1, stempool), inplanes end +### Downsampling layers +## These will almost never be used directly. They are used by the `_make_blocks` function to +## build the downsampling layers. In most cases, these defaults will not need to be changed. +## If you wish to write your own ResNet model using the `_make_blocks` function, you can use +## this function to build the downsampling layers. + +# Downsample layer using convolutions. +function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, + norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size + dilation = kernel_size[1] > 1 ? dilation : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, + dilation, bias = false), + norm_layer(outplanes)) +end + +# Downsample layer using max pooling +function downsample_pool(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, + norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 + pool = MeanPool((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), inplanes => outplanes; bias = false), + norm_layer(outplanes)) +end + +""" + downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), + stride = 1, dilation = 1, norm_layer = BatchNorm) + +Wrapper function that makes it easier to build a downsample block inside a ResNet model. +This function is almost never used directly or customised by the user. + +# Arguments + + - `downsample_fn`: The function to use for downsampling in skip connections. Recommended usage + is passing in either `downsample_conv` or `downsample_pool`. + - `inplanes`: The number of input feature maps. + - `planes`: The number of output feature maps. + - `expansion`: The expansion factor of the block. + - `kernel_size`: The size of the convolutional kernel. + - `stride`: The stride of the convolutional layer. + - `dilation`: The dilation of the convolutional layer. + - `norm_layer`: The normalisation layer to be used. +""" function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), - stride = 1, dilation = 1, first_dilation = dilation, - norm_layer = BatchNorm) + stride = 1, dilation = 1, norm_layer = BatchNorm) if stride != 1 || inplanes != planes * expansion downsample = downsample_fn(kernel_size, inplanes, planes * expansion; - stride, dilation, first_dilation, - norm_layer) + stride, dilation, norm_layer) else downsample = identity end @@ -166,6 +263,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # Stochastic depth linear decay rule (DropPath) dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), sum(block_repeats)) + # Construct each stage for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, block_repeats, _drop_blocks(get(drop_rates, @@ -181,10 +279,11 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride end # Downsample block; either a (default) convolution-based block or a pooling-based block downsample = downsample_block(downsample_fn, inplanes, planes, expansion; - stride, dilation, first_dilation = dilation) + stride, dilation) # Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks + # Different behaviour for the first block of each stage downsample = block_idx == 1 ? downsample : identity stride = block_idx == 1 ? stride : 1 push!(blocks, @@ -209,7 +308,62 @@ function _drop_blocks(drop_block_prob = 0.0) ] end -function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32, +""" + resnet(block_type, layers; inchannels = 3, nclasses = 1000, output_stride = 32, + stem = first(resnet_stem(; inchannels)), inplanes = 64, + downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), + drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = 0.0), + classifier_args::NamedTuple = NamedTuple()) + +This function creates the layers for many ResNet-like models. + +!!! note + + If you are an end-user trying to use ResNet-like models, you should consider [`ResNet`](#) + and similar higher-level functions instead. This version is significantly more customisable + at the cost of being more complicated. + +# Arguments + + - `block_fn`: The type of block to use inside the ResNet model. Must be either `:basicblock`, + which is the standard ResNet block, or `:bottleneck`, which is the ResNet block with a + bottleneck structure. 
See the [paper](https://arxiv.org/abs/1512.03385) for more details. + + - `layers`: A list of integers specifying the number of blocks in each stage. For example, + `[3, 4, 6, 3]` would mean that the network would have 4 stages, with 3, 4, 6 and 3 blocks in + each. + - `nclasses`: The number of output classes. + - `inchannels`: The number of input channels. + - `output_stride`: The total stride of the network i.e. the amount by which the input is + downsampled throughout the network. This is used to determine the output size from the + backbone of the network. Must be one of `[8, 16, 32]`. + - `stem`: A constructed ResNet stem, passed in to be used in the model. `inplanes` should be + set to the number of output channels from this stem. Metalhead provides an in-built + function for creating a stem (see [`resnet_stem`](#)) but you can also create your + own (although this is not usually necessary). + - `inplanes`: The number of output channels from the stem. + - `downsample_type`: The type of downsampling to use. Either `:conv` or `:pool`. The former + uses a traditional convolution-based downsampling, while the latter is an + average-pooling-based downsampling that was suggested in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) + paper. + - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed + to the block function. For more information regarding valid arguments, see + the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). + - `drop_rates`: A `NamedTuple` that can may define none, some or all of the following: + + + `dropout_rate`: The rate of dropout to be used in the classifier head. + + `drop_path_rate`: Stochastic depth implemented using [`DropPath`](#). + + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). + - `classifier_args`: A `NamedTuple` that may define none, some or all of the following: + + + `pool_type`: The type of pooling to use in the classifier head. Uses + [`SelectAdaptivePool`](#) to select the pooling function. See its + documentation for more information. + + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a + `Dense` layer. +""" +function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, @@ -241,7 +395,7 @@ end (m::ResNet)(x) = m.layers(x) """ - ResNet(depth::Integer; pretrain = false, nclasses = 1000) + ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) Creates a ResNet model with the specified depth. ((reference)[https://arxiv.org/abs/1512.03385]) @@ -250,6 +404,7 @@ Creates a ResNet model with the specified depth. - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. - `nclasses`: the number of output classes !!! warning @@ -258,10 +413,10 @@ Creates a ResNet model with the specified depth. Advanced users who want more configuration options will be better served by using [`resnet`](#). """ -function ResNet(depth::Integer; pretrain = false, nclasses = 1000) +function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. 
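Putting the pieces of this docstring together, a hedged sketch of driving the low-level builder directly (all of these are non-exported Metalhead functions; the depths, stem choice and drop rates are illustrative):

```julia
using Metalhead

stem, inplanes = Metalhead.resnet_stem(; stem_type = :deep, inchannels = 3)
layers = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                          inchannels = 3, nclasses = 1000, stem, inplanes,
                          downsample_fn = Metalhead.downsample_pool,
                          drop_rates = (dropout_rate = 0.0, drop_path_rate = 0.1,
                                        drop_block_rate = 0.0))
size(layers(rand(Float32, 224, 224, 3, 1)))  # (1000, 1)
```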
Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses) + layers = resnet(resnet_config[depth]...; inchannels, nclasses) if pretrain loadpretrain!(layers, string("resnet", depth)) end @@ -276,7 +431,8 @@ end (m::ResNeXt)(x) = m.layers(x) """ - ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) Creates a ResNeXt model with the specified depth, cardinality, and base width. ((reference)[https://arxiv.org/abs/1611.05431]) @@ -287,6 +443,7 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels. - `nclasses`: the number of output classes !!! warning @@ -295,11 +452,11 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. Advanced users who want more configuration options will be better served by using [`resnet`](#). """ -function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - nclasses = 1000) +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses, + layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; cardinality, base_width)) if pretrain loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width)) @@ -315,7 +472,7 @@ end (m::SEResNet)(x) = m.layers(x) """ - SEResNet(depth::Integer; pretrain = false, nclasses = 1000) + SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) Creates a SEResNet model with the specified depth. ((reference)[https://arxiv.org/pdf/1709.01507.pdf]) @@ -324,12 +481,19 @@ Creates a SEResNet model with the specified depth. - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: the number of input channels. - `nclasses`: the number of output classes + +!!! warning + + `SEResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). """ -function SEResNet(depth::Integer; pretrain = false, nclasses = 1000) +function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses, + layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; attn_fn = squeeze_excite)) if pretrain loadpretrain!(layers, string("seresnet", depth)) @@ -345,7 +509,8 @@ end (m::SEResNeXt)(x) = m.layers(x) """ - SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000) + SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) Creates a SEResNeXt model with the specified depth, cardinality, and base width. 
((reference)[https://arxiv.org/pdf/1709.01507.pdf]) @@ -356,13 +521,20 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels - `nclasses`: the number of output classes + +!!! warning + + `SEResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). """ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - nclasses = 1000) + inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; nclasses, + layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) if pretrain loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 41a98843e..a86361143 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -25,8 +25,7 @@ include("normalise.jl") export prenorm, ChannelLayerNorm include("conv.jl") -export conv_bn, depthwise_sep_conv_bn, invertedresidual -skip_identity, skip_projection +export conv_bn, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection include("drop.jl") export DropPath, DropBlock @@ -38,7 +37,6 @@ include("classifier.jl") export create_classifier include("pool.jl") -export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool -SelectAdaptivePool +export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool, SelectAdaptivePool end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 3cefe7c0d..7d8ee776d 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -3,7 +3,7 @@ Multi-head self-attention layer. -# Arguments: +# Arguments - `nheads`: Number of heads - `qkv_layer`: layer to be used for getting the query, key and value @@ -22,7 +22,7 @@ end Multi-head self-attention layer. -# Arguments: +# Arguments - `planes`: number of input channels - `nheads`: number of heads diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl index e2a1fe75c..0e9ba02d1 100644 --- a/src/layers/classifier.jl +++ b/src/layers/classifier.jl @@ -1,3 +1,16 @@ +""" + create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) + +Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `pool_type`: the type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, + `:catmeanmax` or `:identity`. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. 
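Based on this description, a hedged usage sketch of the classifier-head factory (internal API; the feature-map size is illustrative, and the expected output shape is inferred from how the ResNet models use this head):

```julia
using Metalhead

head = Metalhead.Layers.create_classifier(2048, 1000; pool_type = :mean, use_conv = false)
head(rand(Float32, 7, 7, 2048, 1))  # expected size (1000, 1): pool, flatten, then `Dense`
```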
+""" function create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling if pool_type == :identity diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e56967aef..7605a6cd1 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,8 +1,7 @@ """ conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, - groups = 1, [bias, weight, init], initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1) + rev = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, + groups = 1, [bias, weight, init]) Create a convolution + batch normalization pair with activation. @@ -22,13 +21,9 @@ Create a convolution + batch normalization pair with activation. - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) - - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) - - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) """ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, - initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1, - kwargs...) + rev = false, preact = false, use_bn = true, kwargs...) if !use_bn (preact || rev) ? throw("preact only supported with `use_bn = true`") : return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)] @@ -48,17 +43,14 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...)) push!(layers, - BatchNorm(Int(bnplanes), activations.bn; - initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum)) + BatchNorm(Int(bnplanes), activations.bn)) return rev ? reverse(layers) : layers end """ depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; rev = false, use_bn = (true, true), - stride = 1, pad = 0, dilation = 1, [bias, weight, init], - initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1) + stride = 1, pad = 0, dilation = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: @@ -82,21 +74,13 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) - - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) - - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) """ function depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; rev = false, use_bn = (true, true), - initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1, stride = 1, kwargs...) 
return vcat(conv_bn(kernelsize, inplanes, inplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, - ϵ = ϵ, momentum = momentum, use_bn = use_bn[1], - stride = stride, groups = Int(inplanes), kwargs...), - conv_bn((1, 1), inplanes, outplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, use_bn = use_bn[2], - ϵ = ϵ, momentum = momentum)) + rev, use_bn = use_bn[1], stride, groups = Int(inplanes), kwargs...), + conv_bn((1, 1), inplanes, outplanes, activation; rev, use_bn = use_bn[2])) end """ @@ -105,10 +89,10 @@ end Create a skip projection ([reference](https://arxiv.org/abs/1512.03385v1)). -# Arguments: +# Arguments - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps - `downsample`: set to `true` to downsample the input """ function skip_projection(inplanes, outplanes, downsample = false) @@ -124,7 +108,7 @@ end Create a identity projection ([reference](https://arxiv.org/abs/1512.03385v1)). -# Arguments: +# Arguments - `inplanes`: the number of input feature maps - `outplanes`: the number of output feature maps @@ -153,15 +137,14 @@ Create a basic inverted residual block for MobileNet variants # Arguments - - `kernel_size`: The kernel size of the convolutional layers - - `inplanes`: The number of input feature maps + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps - `hidden_planes`: The number of feature maps in the hidden layer - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)). - Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; stride, reduction = nothing) diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index 66f25d1c0..3e85f18d9 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -8,7 +8,7 @@ _flatten_spatial(x) = permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1 Patch embedding layer used by many vision transformer-like models to split the input image into patches. -# Arguments: +# Arguments - `imsize`: the size of the input image - `inchannels`: the number of channels in the input. diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c767bd1e0..e71634e22 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -2,7 +2,7 @@ prenorm(planes, fn) = Chain(LayerNorm(planes), fn) """ - ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1f-5) + ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) A variant of LayerNorm where the input is normalised along the channel dimension. 
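Since `conv_bn` and `depthwise_sep_conv_bn` return a vector of layers rather than a `Chain`, a usage sketch (with illustrative sizes) splats the result into a `Chain`:

```julia
using Flux, Metalhead

layers = Metalhead.Layers.depthwise_sep_conv_bn((3, 3), 32, 64; stride = 2, pad = 1)
block = Chain(layers...)  # depthwise 3x3 conv (groups = 32) + BN, then pointwise 1x1 conv + BN
size(block(rand(Float32, 56, 56, 32, 1)))  # (28, 28, 64, 1)
```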
The input is expected to have channel dimension with size @@ -19,7 +19,7 @@ end @functor ChannelLayerNorm -function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) +function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) diag = Flux.Scale(1, 1, sz, λ) return ChannelLayerNorm(diag, ϵ) end diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 4ffe298e3..36a08a8da 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,21 +1,40 @@ -function AdaptiveMeanMaxPool(output_size = (1, 1)) - return 0.5 * Parallel(.+, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) -end +""" + AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) + +A type of adaptive pooling layer which uses both mean and max pooling and combines them to +produce a single output. Note that this is equivalent to +`Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))` + +# Arguments -function AdaptiveCatMeanMaxPool(output_size = (1, 1)) - return Parallel(cat_channels, AdaptiveAvgMaxPool(output_size), - AdaptiveMaxPool(output_size)) + - `output_size`: The size of the output after pooling. + - `connection`: The connection type to use. +""" +function AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) + return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) end +""" + SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) + +Adaptive pooling factory function. + +# Arguments + + - `output_size`: The size of the output after pooling. + - `pool_type`: The type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, + `:catmeanmax` or `:identity`. + - `flatten`: Whether to flatten the output from the pooling layer. +""" function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) if pool_type == :mean pool = AdaptiveMeanPool(output_size) elseif pool_type == :max pool = AdaptiveMaxPool(output_size) elseif pool_type == :meanmax - pool = AdaptiveMeanMaxPool(output_size) + pool = 0.5f0 * AdaptiveMeanMaxPool(output_size) elseif pool_type == :catmeanmax - pool = AdaptiveCatMeanMaxPool(output_size) + pool = AdaptiveMeanMaxPool(output_size; connection = cat_channels) elseif pool_type == :identity pool = identity else diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 7f1a76d59..6d86947c9 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -14,7 +14,6 @@ Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer - Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, activation = relu, gate_activation = sigmoid, diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 5083b228e..aab7cba4e 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -5,7 +5,7 @@ Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) -# Arguments: +# Arguments - `planes`: the number of planes in the block - `npatches`: the number of patches of the input @@ -55,8 +55,8 @@ Creates a model with the MLPMixer architecture. not specified. 
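For the pooling factory earlier in this patch, a hedged sketch of what the `:catmeanmax` option builds (internal API; the feature-map size is illustrative and the channel-concatenating connection is written inline here):

```julia
using Flux, Metalhead

catpool = Metalhead.Layers.AdaptiveMeanMaxPool((1, 1); connection = (xs...) -> cat(xs...; dims = 3))
size(catpool(rand(Float32, 7, 7, 512, 1)))  # (1, 1, 1024, 1): mean- and max-pooled maps stacked on channels
# `SelectAdaptivePool((1, 1); pool_type = :catmeanmax)` builds the same layer using `cat_channels`.
```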
""" function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, - norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0.0, + norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), + embedplanes = 512, drop_path_rate = 0.0, depth = 12, nclasses = 1000, kwargs...) npatches = prod(imsize .÷ patch_size) dp_rates = LinRange{Float32}(0.0, drop_path_rate, depth) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index a06ce6886..b3a7e167c 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -50,7 +50,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, emb_drop_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] - "Pool type must be either :class (class token) or :mean (mean pooling)" + "Pool type must be either `:class` (class token) or `:mean` (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), From 4a91fc479d7f70c1d5d9b26e1e0b13f2a4d22857 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 8 Jul 2022 21:54:48 +0530 Subject: [PATCH 40/64] More aggressive GC Co-authored-by: Brian Chen --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 55e416ac2..1a8c77f25 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,7 @@ const PRETRAINED_MODELS = [ function _gc() GC.safepoint() - GC.gc() + GC.gc(true) end function gradtest(model, input) From cf538bb283a538a4b9867e58c98c5621dd684f2c Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 9 Jul 2022 09:35:31 +0530 Subject: [PATCH 41/64] Tweaks don't stop Neither does formatting, unfortunately. Also refactor `classifier` to separate out FC-layer creation and pooling --- src/convnets/inception.jl | 8 +-- src/convnets/resnets.jl | 78 ++++++++++++++-------------- src/layers/Layers.jl | 10 ++-- src/layers/classifier.jl | 25 --------- src/layers/{mlp-linear.jl => mlp.jl} | 33 ++++++------ src/layers/pool.jl | 34 +----------- src/layers/scale.jl | 24 +++++++++ src/other/mlpmixer.jl | 4 +- src/utilities.jl | 10 ---- src/vit-based/vit.jl | 2 +- 10 files changed, 95 insertions(+), 133 deletions(-) delete mode 100644 src/layers/classifier.jl rename src/layers/{mlp-linear.jl => mlp.jl} (79%) create mode 100644 src/layers/scale.jl diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index 156362cf3..5823e9737 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -425,7 +425,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) + inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -458,7 +458,7 @@ function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000 end """ - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an InceptionResNetv2 model. 
([reference](https://arxiv.org/abs/1602.07261)) @@ -542,7 +542,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end """ - xception(; inchannels = 3, dropout_rate =0.0, nclasses = 1000) + xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) @@ -573,7 +573,7 @@ struct Xception end """ - Xception(; pretrain = false, inchannels = 3, dropout_rate =0.0, nclasses = 1000) + Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) diff --git a/src/convnets/resnets.jl b/src/convnets/resnets.jl index f46ac29ad..a70fc7f07 100644 --- a/src/convnets/resnets.jl +++ b/src/convnets/resnets.jl @@ -1,6 +1,5 @@ # resnet.jl -## It is recommended to check out the user's guide for more information -## regarding the use of these functions. +## It is recommended to check out the user guide for more information. ### ResNet blocks ## These functions return a block to be used inside of a ResNet model. @@ -12,10 +11,9 @@ ## expansion factor of each block and use this to construct the stages of the model. """ - basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = nothing, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, + basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, + dilation = 1, first_dilation = dilation, activation = relu, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) Creates a basic ResNet block. @@ -26,8 +24,6 @@ Creates a basic ResNet block. - `planes`: number of feature maps for the block - `stride`: the stride of the block - `downsample`: the downsampling function to use - - `cardinality`: redundant, kept for compatibility with `bottleneck`. - - `base_width`: redundant, kept for compatibility with `bottleneck`. - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first convolution. - `dilation`: the dilation of the second convolution. @@ -42,14 +38,11 @@ Creates a basic ResNet block. - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the attention function. 
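For a rough sense of the building block on its own, a minimal sketch of constructing one directly (assuming only the defaults documented above; inside a model this is normally done by the stage-building code such as `_make_blocks`, not by hand):

    # 64 => 64 residual block with stride 1, an identity shortcut and no
    # attention, DropBlock or DropPath
    block = basicblock(64, 64)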
""" -function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = dilation, activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, +function basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, + dilation = 1, first_dilation = dilation, activation = relu, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) - @assert cardinality==1 "`basicblock` only supports cardinality of 1" - @assert base_width==64 "`basicblock` does not support changing base width" first_planes = planes ÷ reduce_first outplanes = planes * expansion conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, @@ -69,8 +62,8 @@ expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = dilation, activation = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, first_dilation = 1, + activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -86,7 +79,6 @@ Creates a bottleneck ResNet block. - `base_width`: the number of output feature maps for each convolutional group. - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `dilation`: redundant, kept for compatibility with `basicblock`. - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. @@ -99,8 +91,8 @@ Creates a bottleneck ResNet block. attention function. """ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, dilation = 1, - first_dilation = dilation, activation = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, first_dilation = 1, + activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) @@ -263,12 +255,12 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # Stochastic depth linear decay rule (DropPath) dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), sum(block_repeats)) - # Construct each stage - for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, - block_repeats, - _drop_blocks(get(drop_rates, - :drop_block_rate, - 0)))) + # DropBlock rate + dbr = get(drop_rates, :drop_block_rate, 0) + ## Construct each stage + for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr))) + # Number of planes in each stage, number of blocks in each stage, and the drop block rate + planes, num_blocks, drop_block = itr # Stride calculations for each stage stride = stage_idx == 1 ? 
1 : 2 if net_stride >= output_stride @@ -280,7 +272,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # Downsample block; either a (default) convolution-based block or a pooling-based block downsample = downsample_block(downsample_fn, inplanes, planes, expansion; stride, dilation) - # Construct the blocks for each stage + ## Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks # Different behaviour for the first block of each stage @@ -309,14 +301,16 @@ function _drop_blocks(drop_block_prob = 0.0) end """ - resnet(block_type, layers; inchannels = 3, nclasses = 1000, output_stride = 32, + resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), - classifier_args::NamedTuple = NamedTuple()) + classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), + use_conv = false)) -This function creates the layers for many ResNet-like models. +This function creates the layers for many ResNet-like models. See the user guide for more +information. !!! note @@ -350,16 +344,14 @@ This function creates the layers for many ResNet-like models. - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed to the block function. For more information regarding valid arguments, see the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). - - `drop_rates`: A `NamedTuple` that can may define none, some or all of the following: + - `drop_rates`: A `NamedTuple` that may define none, some or all of the following: + `dropout_rate`: The rate of dropout to be used in the classifier head. + `drop_path_rate`: Stochastic depth implemented using [`DropPath`](#). + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). - - `classifier_args`: A `NamedTuple` that may define none, some or all of the following: + - `classifier_args`: A `NamedTuple` that **must** specify the following arguments: - + `pool_type`: The type of pooling to use in the classifier head. Uses - [`SelectAdaptivePool`](#) to select the pooling function. See its - documentation for more information. + + `pool_layer`: The adaptive pooling layer to use in the classifier head. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. """ @@ -368,15 +360,25 @@ function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), - classifier_args::NamedTuple = NamedTuple()) - # Feature Blocks + classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), + use_conv = false)) + ## Feature Blocks channels = [64, 128, 256, 512] stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fn, drop_rates, block_args) - # Head (Pooling and Classifier) + ## Classifier head expansion = expansion_factor(block_fn) num_features = 512 * expansion - global_pool, fc = create_classifier(num_features, nclasses; classifier_args...) 
+ pool_layer, use_conv = classifier_args + # Pooling + if pool_layer === identity + @assert use_conv + "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" + end + flatten_in_pool = !use_conv && pool_layer !== identity + global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer + # Fully-connected layer + fc = create_fc(num_features, nclasses; use_conv) classifier = Chain(global_pool, Dropout(get(drop_rates, :dropout_rate, 0)), fc) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index a86361143..2c4b11e5a 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -18,8 +18,8 @@ export MHAttention include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens -include("mlp-linear.jl") -export mlp_block, gated_mlp_block, LayerScale +include("mlp.jl") +export mlp_block, gated_mlp_block, create_fc include("normalise.jl") export prenorm, ChannelLayerNorm @@ -33,10 +33,10 @@ export DropPath, DropBlock include("selayers.jl") export squeeze_excite, effective_squeeze_excite -include("classifier.jl") -export create_classifier +include("scale.jl") +export LayerScale, inputscale include("pool.jl") -export AdaptiveMeanMaxPool, AdaptiveCatMeanMaxPool, SelectAdaptivePool +export AdaptiveMeanMaxPool end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl deleted file mode 100644 index 0e9ba02d1..000000000 --- a/src/layers/classifier.jl +++ /dev/null @@ -1,25 +0,0 @@ -""" - create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) - -Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. - -# Arguments - - - `inplanes`: number of input feature maps - - `nclasses`: number of output classes - - `pool_type`: the type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, - `:catmeanmax` or `:identity`. - - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. -""" -function create_classifier(inplanes, nclasses; pool_type = :mean, use_conv = false) - flatten_in_pool = !use_conv # flatten when we use a Dense layer after pooling - if pool_type == :identity - @assert use_conv - "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" - flatten_in_pool = false # disable flattening if pooling is pass-through (no pooling) - end - global_pool = SelectAdaptivePool(; pool_type, flatten = flatten_in_pool) - fc = use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : - Dense(inplanes => nclasses; bias = true) - return global_pool, fc -end diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp.jl similarity index 79% rename from src/layers/mlp-linear.jl rename to src/layers/mlp.jl index 8cca1e266..f72520451 100644 --- a/src/layers/mlp-linear.jl +++ b/src/layers/mlp.jl @@ -1,21 +1,6 @@ -""" - LayerScale(λ, planes::Integer) - -Creates a `Flux.Scale` layer that performs "`LayerScale`" -([reference](https://arxiv.org/abs/2103.17239)). - -# Arguments - - - `planes`: Size of channel dimension in the input. - - `λ`: initialisation value for the learnable diagonal matrix. -""" -function LayerScale(planes::Integer, λ) - return λ > 0 ? 
Flux.Scale(fill(Float32(λ), planes), false) : identity -end - """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout_rate =0., activation = gelu) + dropout_rate = 0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -60,3 +45,19 @@ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) + +""" + create_fc(inplanes, nclasses; use_conv = false) + +Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. +""" +function create_fc(inplanes, nclasses; use_conv = false) + return use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : + Dense(inplanes => nclasses; bias = true) +end diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 36a08a8da..0a74a24c0 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,5 +1,5 @@ """ - AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) + AdaptiveMeanMaxPool(output_size = (1, 1); connection = +) A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to @@ -10,36 +10,6 @@ produce a single output. Note that this is equivalent to - `output_size`: The size of the output after pooling. - `connection`: The connection type to use. """ -function AdaptiveMeanMaxPool(output_size = (1, 1); connection = .+) +function AdaptiveMeanMaxPool(output_size = (1, 1); connection = +) return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) end - -""" - SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) - -Adaptive pooling factory function. - -# Arguments - - - `output_size`: The size of the output after pooling. - - `pool_type`: The type of adaptive pooling to use. One of `:mean`, `:max`, `:meanmax`, - `:catmeanmax` or `:identity`. - - `flatten`: Whether to flatten the output from the pooling layer. -""" -function SelectAdaptivePool(output_size = (1, 1); pool_type = :mean, flatten = false) - if pool_type == :mean - pool = AdaptiveMeanPool(output_size) - elseif pool_type == :max - pool = AdaptiveMaxPool(output_size) - elseif pool_type == :meanmax - pool = 0.5f0 * AdaptiveMeanMaxPool(output_size) - elseif pool_type == :catmeanmax - pool = AdaptiveMeanMaxPool(output_size; connection = cat_channels) - elseif pool_type == :identity - pool = identity - else - throw(AssertionError("Invalid pool type: $pool_type")) - end - flatten_fn = flatten ? MLUtils.flatten : identity - return Chain(pool, flatten_fn) -end diff --git a/src/layers/scale.jl b/src/layers/scale.jl new file mode 100644 index 000000000..cd55fc97c --- /dev/null +++ b/src/layers/scale.jl @@ -0,0 +1,24 @@ +""" + inputscale(λ; activation = identity) + +Scale the input by a scalar `λ` and applies an activation function to it. +Equivalent to `activation.(λ .* x)`. +""" +inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) +_input_scale(x, λ, activation) = activation.(λ .* x) +_input_scale(x, λ, ::typeof(identity)) = λ .* x + +""" + LayerScale(λ, planes::Integer) + +Creates a `Flux.Scale` layer that performs "`LayerScale`" +([reference](https://arxiv.org/abs/2103.17239)). 
+ +# Arguments + + - `planes`: Size of channel dimension in the input. + - `λ`: initialisation value for the learnable diagonal matrix. +""" +function LayerScale(planes::Integer, λ) + return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity +end diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index aab7cba4e..48f1efd8c 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout_rate =0., drop_path_rate = 0., activation = gelu) + dropout_rate = 0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -115,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] """ - resmixerblock(planes, npatches; dropout_rate =0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, activation = gelu, λ = 1e-4) Creates a block for the ResMixer architecture. diff --git a/src/utilities.jl b/src/utilities.jl index 930cc621a..9c29350bd 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -18,16 +18,6 @@ Convenient reduction operator for use with `Parallel`. """ cat_channels(xy...) = cat(xy...; dims = Val(3)) -""" - inputscale(λ; activation = identity) - -Scale the input by a scalar `λ` and applies an activation function to it. -Equivalent to `activation.(λ .* x)`. -""" -inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) -_input_scale(x, λ, activation) = activation.(λ .* x) -_input_scale(x, λ, ::typeof(identity)) = λ .* x - """ swapdims(perm) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index b3a7e167c..856b64697 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate =0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). 
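As a point of reference for the `scale.jl` helpers added above, `inputscale` simply closes over `λ` and the activation, so (a minimal sketch):

    scale = inputscale(0.5f0; activation = relu)
    scale(Float32[2, -4])    # relu.(0.5f0 .* x) == Float32[1, 0]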
From 5be45efea79204a6519e637e5adc031b015b06d6 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 9 Jul 2022 18:06:28 +0530 Subject: [PATCH 42/64] Reorganisation and formatting It really does never stop Co-Authored-By: Kyle Daruwalla --- src/Metalhead.jl | 6 +- src/convnets/{resnets.jl => resnets/core.jl} | 172 +------------------ src/convnets/resnets/resnet.jl | 35 ++++ src/convnets/resnets/resnext.jl | 40 +++++ src/convnets/resnets/seresnet.jl | 77 +++++++++ test/convnets.jl | 6 +- 6 files changed, 169 insertions(+), 167 deletions(-) rename src/convnets/{resnets.jl => resnets/core.jl} (77%) create mode 100644 src/convnets/resnets/resnet.jl create mode 100644 src/convnets/resnets/resnext.jl create mode 100644 src/convnets/resnets/seresnet.jl diff --git a/src/Metalhead.jl b/src/Metalhead.jl index e60eff405..e88279270 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -22,13 +22,17 @@ include("convnets/alexnet.jl") include("convnets/vgg.jl") include("convnets/inception.jl") include("convnets/googlenet.jl") -include("convnets/resnets.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") +## ResNets +include("convnets/resnets/core.jl") +include("convnets/resnets/resnet.jl") +include("convnets/resnets/resnext.jl") +include("convnets/resnets/seresnet.jl") # Other models include("other/mlpmixer.jl") diff --git a/src/convnets/resnets.jl b/src/convnets/resnets/core.jl similarity index 77% rename from src/convnets/resnets.jl rename to src/convnets/resnets/core.jl index a70fc7f07..c53575c02 100644 --- a/src/convnets/resnets.jl +++ b/src/convnets/resnets/core.jl @@ -1,4 +1,3 @@ -# resnet.jl ## It is recommended to check out the user guide for more information. ### ResNet blocks @@ -11,7 +10,7 @@ ## expansion factor of each block and use this to construct the stages of the model. """ - basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, + basicblock(inplanes, planes; stride = 1, downsample = identity, reduction_factor = 1, dilation = 1, first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -24,7 +23,7 @@ Creates a basic ResNet block. - `planes`: number of feature maps for the block - `stride`: the stride of the block - `downsample`: the downsampling function to use - - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - `dilation`: the dilation of the second convolution. - `first_dilation`: the dilation of the first convolution. @@ -38,12 +37,13 @@ Creates a basic ResNet block. - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the attention function. 
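For instance, attaching a squeeze-and-excitation layer to the residual branch looks roughly like this (an illustrative sketch; `squeeze_excite` is the helper from `src/layers/selayers.jl`, and the keyword value is arbitrary):

    block = basicblock(64, 64; attn_fn = squeeze_excite,
                       attn_args = (; reduction = 8))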
""" -function basicblock(inplanes, planes; stride = 1, downsample = identity, reduce_first = 1, +function basicblock(inplanes, planes; stride = 1, downsample = identity, + reduction_factor = 1, dilation = 1, first_dilation = dilation, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) - first_planes = planes ÷ reduce_first + first_planes = planes ÷ reduction_factor outplanes = planes * expansion conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, dilation = first_dilation, bias = false), @@ -62,7 +62,7 @@ expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, first_dilation = 1, + base_width = 64, reduction_factor = 1, first_dilation = 1, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -77,7 +77,7 @@ Creates a bottleneck ResNet block. - `downsample`: the downsampling function to use - `cardinality`: the number of groups in the convolution. - `base_width`: the number of output feature maps for each convolutional group. - - `reduce_first`: the reduction factor that the input feature maps are reduced by before the first + - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. @@ -91,13 +91,13 @@ Creates a bottleneck ResNet block. attention function. """ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduce_first = 1, first_dilation = 1, + base_width = 64, reduction_factor = 1, first_dilation = 1, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality - first_planes = width ÷ reduce_first + first_planes = width ÷ reduction_factor outplanes = planes * expansion conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), norm_layer(first_planes, activation)) @@ -389,157 +389,3 @@ const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) -struct ResNet - layers::Any -end -@functor ResNet - -(m::ResNet)(x) = m.layers(x) - -""" - ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - -Creates a ResNet model with the specified depth. -((reference)[https://arxiv.org/abs/1512.03385]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `inchannels`: The number of input channels. - - `nclasses`: the number of output classes - -!!! warning - - `ResNet` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 101, 152] - "Invalid depth. 
Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("resnet", depth)) - end - return ResNet(layers) -end - -struct ResNeXt - layers::Any -end -@functor ResNeXt - -(m::ResNeXt)(x) = m.layers(x) - -""" - ResNeXt(depth::Integer; pretrain = false, cardinality = 32, - base_width = 4, inchannels = 3, nclasses = 1000) - -Creates a ResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/abs/1611.05431]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - - `base_width`: the number of feature maps in each group. - - `inchannels`: the number of input channels. - - `nclasses`: the number of output classes - -!!! warning - - `ResNeXt` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, - base_width = 4, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width)) - if pretrain - loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width)) - end - return ResNeXt(layers) -end - -struct SEResNet - layers::Any -end -@functor SEResNet - -(m::SEResNet)(x) = m.layers(x) - -""" - SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - -Creates a SEResNet model with the specified depth. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `inchannels`: the number of input channels. - - `nclasses`: the number of output classes - -!!! warning - - `SEResNet` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 101, 152] - "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; attn_fn = squeeze_excite)) - if pretrain - loadpretrain!(layers, string("seresnet", depth)) - end - return SEResNet(layers) -end - -struct SEResNeXt - layers::Any -end -@functor SEResNeXt - -(m::SEResNeXt)(x) = m.layers(x) - -""" - SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - inchannels = 3, nclasses = 1000) - -Creates a SEResNeXt model with the specified depth, cardinality, and base width. -((reference)[https://arxiv.org/pdf/1709.01507.pdf]) - -# Arguments - - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - - `base_width`: the number of feature maps in each group. - - `inchannels`: the number of input channels - - `nclasses`: the number of output classes - -!!! 
warning - - `SEResNeXt` does not currently support pretrained weights. - -Advanced users who want more configuration options will be better served by using [`resnet`](#). -""" -function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, - inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) - if pretrain - loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) - end - return SEResNeXt(layers) -end diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl new file mode 100644 index 000000000..cd26e69a3 --- /dev/null +++ b/src/convnets/resnets/resnet.jl @@ -0,0 +1,35 @@ +""" + ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a ResNet model with the specified depth. +((reference)[https://arxiv.org/abs/1512.03385]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `ResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct ResNet + layers::Any +end +@functor ResNet + +(m::ResNet)(x) = m.layers(x) + +function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("ResNet", depth)) + end + return ResNet(layers) +end diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl new file mode 100644 index 000000000..1fa00a7b0 --- /dev/null +++ b/src/convnets/resnets/resnext.jl @@ -0,0 +1,40 @@ +""" + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) + +Creates a ResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/abs/1611.05431]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `ResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct ResNeXt + layers::Any +end +@functor ResNeXt + +(m::ResNeXt)(x) = m.layers(x) + +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + block_args = (; cardinality, base_width)) + if pretrain + loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) + end + return ResNeXt(layers) +end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl new file mode 100644 index 000000000..58c6a0607 --- /dev/null +++ b/src/convnets/resnets/seresnet.jl @@ -0,0 +1,77 @@ +""" + SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a SEResNet model with the specified depth. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: the number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `SEResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct SEResNet + layers::Any +end +@functor SEResNet + +(m::SEResNet)(x) = m.layers(x) + +function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + @assert depth in [18, 34, 50, 101, 152] + "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + block_args = (; attn_fn = squeeze_excite)) + if pretrain + loadpretrain!(layers, string("SEResNet", depth)) + end + return SEResNet(layers) +end + +""" + SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) + +Creates a SEResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels + - `nclasses`: the number of output classes + +!!! warning + + `SEResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct SEResNeXt + layers::Any +end +@functor SEResNeXt + +(m::SEResNeXt)(x) = m.layers(x) + +function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) + @assert depth in [50, 101, 152] + "Invalid depth. 
Must be one of [50, 101, 152]" + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) + if pretrain + loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) + end + return SEResNeXt(layers) +end diff --git a/test/convnets.jl b/test/convnets.jl index f6993cdb7..f99b204dc 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -62,7 +62,7 @@ end @testset for base_width in [4, 8] m = ResNeXt(depth; cardinality, base_width) @test size(m(x_224)) == (1000, 1) - if string("resnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS @test acctest(ResNeXt(depth, pretrain = true)) else @test_throws ArgumentError ResNeXt(depth, pretrain = true) @@ -78,7 +78,7 @@ end @testset for depth in [18, 34, 50, 101, 152] m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) - if string("seresnet", depth) in PRETRAINED_MODELS + if (SEResNet, depth) in PRETRAINED_MODELS @test acctest(SEResNet(depth, pretrain = true)) else @test_throws ArgumentError SEResNet(depth, pretrain = true) @@ -94,7 +94,7 @@ end @testset for base_width in [4, 8] m = SEResNeXt(depth; cardinality, base_width) @test size(m(x_224)) == (1000, 1) - if string("seresnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS @test acctest(SEResNeXt(depth, pretrain = true)) else @test_throws ArgumentError SEResNeXt(depth, pretrain = true) From 1e509df1f78cca840e19754733c794993c7c3c72 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 9 Jul 2022 23:09:06 +0530 Subject: [PATCH 43/64] Refactor shortcut connections --- Project.toml | 1 + src/Metalhead.jl | 1 + src/convnets/inception.jl | 2 +- src/convnets/resnets/core.jl | 114 ++++++++++++++++++++++++----------- src/layers/scale.jl | 6 +- src/utilities.jl | 22 +++++++ 6 files changed, 108 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index c83b146a5..d054bdc7e 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/Metalhead.jl b/src/Metalhead.jl index e88279270..67731825e 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,6 +7,7 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils +using PartialFunctions using Random import Functors diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index 5823e9737..ba9c935f6 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -530,7 +530,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, inc = inchannels outc = i == nrepeats ? 
outchannels : inchannels end - push!(layers, x -> relu.(x)) + push!(layers, relu) append!(layers, depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, use_bn = (false, false))) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index c53575c02..295a7698f 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -10,8 +10,9 @@ ## expansion factor of each block and use this to construct the stages of the model. """ - basicblock(inplanes, planes; stride = 1, downsample = identity, reduction_factor = 1, - dilation = 1, first_dilation = dilation, activation = relu, + basicblock(inplanes, planes; stride = 1, downsample = identity, + reduction_factor = 1, dilation = 1, first_dilation = dilation, + activation = relu, connection = addact\$activation, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) @@ -28,6 +29,9 @@ Creates a basic ResNet block. - `dilation`: the dilation of the second convolution. - `first_dilation`: the dilation of the first convolution. - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -38,8 +42,8 @@ Creates a basic ResNet block. attention function. """ function basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, - dilation = 1, first_dilation = dilation, activation = relu, + reduction_factor = 1, dilation = 1, first_dilation = dilation, + activation = relu, connection = addact$activation, norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(basicblock) @@ -53,18 +57,17 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, dilation = dilation, bias = false), norm_layer(outplanes)) attn_layer = attn_fn(outplanes; attn_args...) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, - drop_path)), - activation) + return Parallel(connection, downsample, + Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, + drop_path)) end expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, + activation = relu, connection = addact\$activation, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) Creates a bottleneck ResNet block. @@ -81,6 +84,9 @@ Creates a bottleneck ResNet block. convolution. - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. 
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -92,8 +98,8 @@ Creates a bottleneck ResNet block. """ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, norm_layer = BatchNorm, - drop_block = identity, drop_path = identity, + activation = relu, connection = addact$activation, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality @@ -106,10 +112,9 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina norm_layer(width)) conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) attn_layer = attn_fn(outplanes; attn_args...) - return Chain(Parallel(+, downsample, - Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, - attn_layer, drop_path)), - activation) + return Parallel(connection, downsample, + Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, + attn_layer, drop_path)) end expansion_factor(::typeof(bottleneck)) = 4 @@ -209,6 +214,21 @@ function downsample_pool(kernel_size, inplanes, outplanes; stride = 1, dilation norm_layer(outplanes)) end +# Downsample layer which is an identity projection. Uses max pooling +# when the output size is more than the input size. +function downsample_identity(kernel_size, inplanes, outplanes; kwargs...) + if outplanes > inplanes + return Chain(MaxPool((1, 1); stride = 2), + y -> cat_channels(y, + zeros(eltype(y), + size(y, 1), + size(y, 2), + outplanes - inplanes, size(y, 4)))) + else + return identity + end +end + """ downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), stride = 1, dilation = 1, norm_layer = BatchNorm) @@ -228,24 +248,47 @@ This function is almost never used directly or customised by the user. - `dilation`: The dilation of the convolutional layer. - `norm_layer`: The normalisation layer to be used. """ -function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), - stride = 1, dilation = 1, norm_layer = BatchNorm) +function downsample_block(downsample_fns, inplanes, planes, expansion; + kernel_size = (1, 1), stride = 1, dilation = 1, + norm_layer = BatchNorm) + down_fn1, down_fn2 = downsample_fns if stride != 1 || inplanes != planes * expansion - downsample = downsample_fn(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) + downsample = down_fn2(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) else - downsample = identity + downsample = down_fn1(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) end return downsample end +const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv)) + +function _make_downsample_fns(vec::Vector{T}) where {T} + if T <: Symbol + downs = [] + for i in vec + @assert i in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + push!(downs, shortcut_dict[i]) + end + return downs + elseif T <: NTuple{2} + return vec + else + throw(ArgumentError("The shortcut list must be a `Vector` of `Symbol`s or `NTuple{2}`s")) + end +end + # Makes the main stages of the ResNet model. 
This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, - downsample_fn = downsample_conv, - drop_rates::NamedTuple, block_args::NamedTuple) + downsample_fns::Vector, drop_rates::NamedTuple, + block_args::NamedTuple) @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" expansion = expansion_factor(block_fn) stages = [] @@ -258,9 +301,10 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride # DropBlock rate dbr = get(drop_rates, :drop_block_rate, 0) ## Construct each stage - for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr))) + for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr), + downsample_fns)) # Number of planes in each stage, number of blocks in each stage, and the drop block rate - planes, num_blocks, drop_block = itr + planes, num_blocks, drop_block, down_fns = itr # Stride calculations for each stage stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride @@ -270,7 +314,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride net_stride *= stride end # Downsample block; either a (default) convolution-based block or a pooling-based block - downsample = downsample_block(downsample_fn, inplanes, planes, expansion; + downsample = downsample_block(down_fns, inplanes, planes, expansion; stride, dilation) ## Construct the blocks for each stage blocks = [] @@ -355,17 +399,19 @@ information. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. 
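For example, a ResNeXt-50-style backbone with a little stochastic depth could be assembled roughly as follows (the values are illustrative only and mirror how the `ResNet`/`ResNeXt` wrappers in this PR call `resnet`):

    layers = resnet(bottleneck, [3, 4, 6, 3], [:B, :B, :B, :B];
                    block_args = (; cardinality = 32, base_width = 4),
                    drop_rates = (dropout_rate = 0.0, drop_path_rate = 0.1,
                                  drop_block_rate = 0.0))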
""" -function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, +function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; + inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, - downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), + block_args::NamedTuple = NamedTuple(), drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0), classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false)) ## Feature Blocks channels = [64, 128, 256, 512] + downsample_fns = _make_downsample_fns(downsample_list) stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; - output_stride, downsample_fn, drop_rates, block_args) + output_stride, downsample_fns, drop_rates, block_args) ## Classifier head expansion = expansion_factor(block_fn) num_features = 512 * expansion @@ -384,8 +430,8 @@ function resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride end # block-layer configurations for ResNet and ResNeXt models -const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), - 34 => (basicblock, [3, 4, 6, 3]), - 50 => (bottleneck, [3, 4, 6, 3]), - 101 => (bottleneck, [3, 4, 23, 3]), - 152 => (bottleneck, [3, 8, 36, 3])) +const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2], [:A, :B, :B, :B]), + 34 => (basicblock, [3, 4, 6, 3], [:A, :B, :B, :B]), + 50 => (bottleneck, [3, 4, 6, 3], [:B, :B, :B, :B]), + 101 => (bottleneck, [3, 4, 23, 3], [:B, :B, :B, :B]), + 152 => (bottleneck, [3, 8, 36, 3], [:B, :B, :B, :B])) diff --git a/src/layers/scale.jl b/src/layers/scale.jl index cd55fc97c..965b50f38 100644 --- a/src/layers/scale.jl +++ b/src/layers/scale.jl @@ -4,9 +4,9 @@ Scale the input by a scalar `λ` and applies an activation function to it. Equivalent to `activation.(λ .* x)`. """ -inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) -_input_scale(x, λ, activation) = activation.(λ .* x) -_input_scale(x, λ, ::typeof(identity)) = λ .* x +inputscale(λ; activation = identity) = _input_scale$(λ, activation) +_input_scale(λ, activation, x) = activation.(λ .* x) +_input_scale(λ, ::typeof(identity), x) = λ .* x """ LayerScale(λ, planes::Integer) diff --git a/src/utilities.jl b/src/utilities.jl index 9c29350bd..938c598f0 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -9,6 +9,28 @@ function _round_channels(channels, divisor, min_value = divisor) return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels end +""" + addact(activation = relu, xs...) + +Convenience function for applying an activation function to the output after +summing up the input arrays. Useful as the `connection` argument for the block +function in [`resnet`](#). + +See also [`reluadd`](#). +""" +addact(activation = relu, xs...) = activation(sum(tuple(xs...))) + +""" + actadd(activation = relu, xs...) + +Convenience function for adding input arrays after applying an activation +function to them. Useful as the `connection` argument for the block function in +[`resnet`](#). + +See also [`addrelu`](#). +""" +actadd(activation = relu, xs...) = sum(activation.(tuple(xs...))) + """ cat_channels(x, y, zs...) 
From e4930f1513aa40ee722f8f3f25b9359ba0cc1d8b Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 10 Jul 2022 10:27:12 +0530 Subject: [PATCH 44/64] Generalise `resnet` further --- src/convnets/resnets/core.jl | 17 +++++++++-------- src/convnets/resnets/resnet.jl | 8 +++++++- src/convnets/resnets/seresnet.jl | 8 ++++---- src/layers/Layers.jl | 1 + test/convnets.jl | 2 +- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 295a7698f..289a7812d 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -399,7 +399,8 @@ information. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. """ -function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; +function resnet(block_fn, layers, + downsample_list::Vector = collect(:B for _ in 1:length(layers)); inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, block_args::NamedTuple = NamedTuple(), @@ -408,7 +409,7 @@ function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false)) ## Feature Blocks - channels = [64, 128, 256, 512] + channels = collect(64 * 2^i for i in range(0, length(layers))) downsample_fns = _make_downsample_fns(downsample_list) stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fns, drop_rates, block_args) @@ -429,9 +430,9 @@ function resnet(block_fn, layers, downsample_list::Vector = [:A, :B, :B, :B]; return Chain(Chain(stem, stage_blocks), classifier) end -# block-layer configurations for ResNet and ResNeXt models -const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2], [:A, :B, :B, :B]), - 34 => (basicblock, [3, 4, 6, 3], [:A, :B, :B, :B]), - 50 => (bottleneck, [3, 4, 6, 3], [:B, :B, :B, :B]), - 101 => (bottleneck, [3, 4, 23, 3], [:B, :B, :B, :B]), - 152 => (bottleneck, [3, 8, 36, 3], [:B, :B, :B, :B])) +# block-layer configurations for ResNet-like models +const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), + 34 => (basicblock, [3, 4, 6, 3]), + 50 => (bottleneck, [3, 4, 6, 3]), + 101 => (bottleneck, [3, 4, 23, 3]), + 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index cd26e69a3..3356ef225 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -1,3 +1,9 @@ +const resnet_shortcuts = Dict(18 => [:A, :B, :B, :B], + 34 => [:A, :B, :B, :B], + 50 => [:B, :B, :B, :B], + 101 => [:B, :B, :B, :B], + 152 => [:B, :B, :B, :B]) + """ ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @@ -27,7 +33,7 @@ end function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses) + layers = resnet(resnet_config[depth]..., resnet_shortcuts[depth]; inchannels, nclasses) if pretrain loadpretrain!(layers, string("ResNet", depth)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 58c6a0607..605c074d6 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -6,7 +6,7 @@ Creates a SEResNet model with the specified depth. 
# Arguments - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: the number of input channels. - `nclasses`: the number of output classes @@ -25,8 +25,8 @@ end (m::SEResNet)(x) = m.layers(x) function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 101, 152] - "Invalid depth. Must be one of [18, 34, 50, 101, 152]" + @assert depth in [50, 101, 152] + "Invalid depth. Must be one of [50, 101, 152]" layers = resnet(resnet_config[depth]...; inchannels, nclasses, block_args = (; attn_fn = squeeze_excite)) if pretrain @@ -44,7 +44,7 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. # Arguments - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. - `base_width`: the number of feature maps in each group. diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 2c4b11e5a..e0b870fe9 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -8,6 +8,7 @@ using Functors using ChainRulesCore using Statistics using MLUtils +using PartialFunctions using Random include("../utilities.jl") diff --git a/test/convnets.jl b/test/convnets.jl index f99b204dc..601f51421 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -75,7 +75,7 @@ end end @testset "SEResNet" begin - @testset for depth in [18, 34, 50, 101, 152] + @testset for depth in [50, 101, 152] m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) if (SEResNet, depth) in PRETRAINED_MODELS From 80bdcded05fdc735f29efb4f6803813e55045fbe Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 10 Jul 2022 14:20:43 +0530 Subject: [PATCH 45/64] Documentation And make `downsample_opts` a smidge easier to work with. Also, a wee bit o' formatting and cleanup. --- src/convnets/resnets/core.jl | 94 ++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 289a7812d..e55eb7a76 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -130,15 +130,14 @@ on how to use this function. - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. + `:default`: Builds a stem based on the default ResNet stem, which consists of a single - 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 - max pooling layer with stride 2. - + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using a - deeper stem with 3 successive 3x3 convolutions having normalisation layers - after each one. This is followed by a 3x3 max pooling layer with stride 2. + 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling + layer with stride 2. + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using + a deeper stem with 3 successive 3x3 convolutions having normalisation layers after each + one. This is followed by a 3x3 max pooling layer with stride 2. 
+ `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second - convolution. This is an experimental variant from the `timm` library - in Python that shows peformance improvements over the `:deep` stem - in some cases. + convolution. This is an experimental variant from the `timm` library in Python that + shows peformance improvements over the `:deep` stem in some cases. - `inchannels`: The number of channels in the input. - `replace_stem_pool`: Whether to replace the default 3x3 max pooling layer with a @@ -253,20 +252,27 @@ function downsample_block(downsample_fns, inplanes, planes, expansion; norm_layer = BatchNorm) down_fn1, down_fn2 = downsample_fns if stride != 1 || inplanes != planes * expansion - downsample = down_fn2(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) + return down_fn1(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) else - downsample = down_fn1(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) + return down_fn2(kernel_size, inplanes, planes * expansion; + stride, dilation, norm_layer) end - return downsample end +# Shortcut configurations for the ResNet models const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv)) - -function _make_downsample_fns(vec::Vector{T}) where {T} + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) + +# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is +# specified as a `Vector` of `Symbol`s. This is used to make the downsample +# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is +# already an `NTuple{2}` of functions, it is returned unchanged. +function _make_downsample_fns(vec::Vector{T}, layers) where {T} + @assert length(vec) == length(layers) + "The length of the downsample `Vector` must match the number of stages" if T <: Symbol downs = [] for i in vec @@ -282,6 +288,13 @@ function _make_downsample_fns(vec::Vector{T}) where {T} end end +function _make_downsample_fns(sym::Symbol, layers) + @assert sym in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + return collect(shortcut_dict[sym] for _ in 1:length(layers)) +end +_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) + # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. 
A block must define a function @@ -345,13 +358,14 @@ function _drop_blocks(drop_block_prob = 0.0) end """ - resnet(block_fn, layers; inchannels = 3, nclasses = 1000, output_stride = 32, - stem = first(resnet_stem(; inchannels)), inplanes = 64, - downsample_fn = downsample_conv, block_args::NamedTuple = NamedTuple(), - drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), - use_conv = false)) + resnet(block_fn, layers, downsample_opt = :B; + inchannels = 3, nclasses = 1000, output_stride = 32, + stem = first(resnet_stem(; inchannels)), inplanes = 64, + block_args::NamedTuple = NamedTuple(), + drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = 0.0), + classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), + use_conv = false)) This function creates the layers for many ResNet-like models. See the user guide for more information. @@ -360,7 +374,7 @@ information. If you are an end-user trying to use ResNet-like models, you should consider [`ResNet`](#) and similar higher-level functions instead. This version is significantly more customisable - at the cost of being more complicated. + at the cost of being significantly more complicated. # Arguments @@ -371,6 +385,25 @@ information. - `layers`: A list of integers specifying the number of blocks in each stage. For example, `[3, 4, 6, 3]` would mean that the network would have 4 stages, with 3, 4, 6 and 3 blocks in each. + - `downsample_opt`: Downsampling options. This can be any one of the following: + + + A single `Symbol` specifying the downsample option to use for all stages. The default + is `:B`, which corresponds to a 1x1 convolution-based downsample for every stage except + the first, which uses an identity projection. The other options are `:A`, which uses + an identity projection for all stages, `:C`, which uses a convolution-based + downsample for all stages and `:D`, which uses an average-pooling-based downsample for every + stage except the first, which uses an identity projection. `:A`, `:B` and `:C` are + described in the [paper](https://arxiv.org/abs/1512.03385), while `:D` is + described in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) paper. + + A `Vector` of `Symbol`s specifying the downsample options to use for each stage. The + choices are the same as the single option above. The length of this `Vector` must be + the same as the length of `layers`. + + A `Vector` of `NTuple{2}`s specifying the downsample functions to use for each stage. + The functions have to be passed in directly here - see [`downsample_identity`](#), + [`downsample_conv`](#), and [`downsample_pool`](#). The first element of each tuple is + the downsample function used when the input needs to be projected (a stride other than 1 + or a change in the number of channels), and the second element is used otherwise. The + length of this `Vector` must be the same as the length of `layers`. - `nclasses`: The number of output classes. - `inchannels`: The number of input channels. - `output_stride`: The total stride of the network i.e. the amount by which the input is downsampled throughout the network. This is used to determine the output size from the backbone of the network. Must be one of `[8, 16, 32]`. - `stem`: A constructed ResNet stem, passed in to be used in the model. `inplanes` should be set to the number of output channels from this stem. Metalhead provides an in-built function for creating a stem (see [`resnet_stem`](#)) but you can also create your own (although this is not usually necessary). - `inplanes`: The number of output channels from the stem. - - `downsample_type`: The type of downsampling to use. Either `:conv` or `:pool`.
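(As a rough usage sketch of the new `downsample_opt` argument, assuming the `bottleneck` block and the `resnet` method from this patch; the calls are illustrative and not taken from the test suite:

    resnet(bottleneck, [3, 4, 6, 3], :D)                # one shortcut style for all stages
    resnet(bottleneck, [3, 4, 6, 3], [:C, :B, :B, :B])  # a per-stage choice, one entry per stage

Both forms go through `_make_downsample_fns`, which expands them into one `NTuple{2}` of downsample functions per stage.)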
The former - uses a traditional convolution-based downsampling, while the latter is an - average-pooling-based downsampling that was suggested in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) - paper. - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed to the block function. For more information regarding valid arguments, see the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). @@ -395,12 +424,13 @@ information. + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). - `classifier_args`: A `NamedTuple` that **must** specify the following arguments: - + `pool_layer`: The adaptive pooling layer to use in the classifier head. + + `pool_layer`: The pooling layer to use in the classifier head. Pass this in with the + arguments to the layer defined. For example, if you want to use an adaptive mean pooling + layer, you would pass in `AdaptiveMeanPool((1, 1))`. + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a `Dense` layer. """ -function resnet(block_fn, layers, - downsample_list::Vector = collect(:B for _ in 1:length(layers)); +function resnet(block_fn, layers, downsample_opt = :B; inchannels = 3, nclasses = 1000, output_stride = 32, stem = first(resnet_stem(; inchannels)), inplanes = 64, block_args::NamedTuple = NamedTuple(), @@ -410,7 +440,7 @@ function resnet(block_fn, layers, use_conv = false)) ## Feature Blocks channels = collect(64 * 2^i for i in range(0, length(layers))) - downsample_fns = _make_downsample_fns(downsample_list) + downsample_fns = _make_downsample_fns(downsample_opt, layers) stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; output_stride, downsample_fns, drop_rates, block_args) ## Classifier head From ab379014be506f9fc2e617db309385050e5fd3dd Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 12 Jul 2022 11:52:46 +0530 Subject: [PATCH 46/64] Add classifier and backbone methods --- src/convnets/resnets/resnet.jl | 3 +++ src/convnets/resnets/resnext.jl | 3 +++ src/convnets/resnets/seresnet.jl | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 3356ef225..ffbee32dc 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -39,3 +39,6 @@ function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 100 end return ResNet(layers) end + +backbone(m::ResNet) = m.layers[1] +classifier(m::ResNet) = m.layers[2] diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 1fa00a7b0..dc1de9464 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -38,3 +38,6 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, end return ResNeXt(layers) end + +backbone(m::ResNeXt) = m.layers[1] +classifier(m::ResNeXt) = m.layers[2] diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 605c074d6..bc57a08fc 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -35,6 +35,9 @@ function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1 return SEResNet(layers) end +backbone(m::SEResNet) = m.layers[1] +classifier(m::SEResNet) = m.layers[2] + """ SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) @@ -75,3 +78,6 @@ function SEResNeXt(depth::Integer; pretrain = false, 
cardinality = 32, base_widt end return SEResNeXt(layers) end + +backbone(m::SEResNeXt) = m.layers[1] +classifier(m::SEResNeXt) = m.layers[2] From 68abbb79e293bda94163718fbc790e01a1d0c689 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 17 Jul 2022 09:30:32 +0530 Subject: [PATCH 47/64] Refactor of resnet core --- src/convnets/convnext.jl | 2 +- src/convnets/efficientnet.jl | 3 +- src/convnets/resnets/core.jl | 531 ++++++++++--------------------- src/convnets/resnets/resnext.jl | 3 +- src/convnets/resnets/seresnet.jl | 7 +- src/layers/Layers.jl | 2 +- src/layers/drop.jl | 10 + src/other/mlpmixer.jl | 2 +- 8 files changed, 193 insertions(+), 367 deletions(-) diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 6ced7eeb9..113a35142 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -51,7 +51,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 push!(downsample_layers, downsample_layer) end stages = [] - dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths)) + dp_rates = droppath_rates(drop_path_rate; depth = sum(depths)) cur = 0 for i in eachindex(depths) push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]]) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index da9000468..02c5b6eb6 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -79,8 +79,7 @@ const efficientnet_block_configs = [ # w: width scaling # d: depth scaling # r: image resolution -const efficientnet_global_configs = Dict( - # (r, (w, d)) +const efficientnet_global_configs = Dict(# (r, (w, d)) :b0 => (224, (1.0, 1.0)), :b1 => (240, (1.0, 1.1)), :b2 => (260, (1.1, 1.2)), diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index e55eb7a76..4b76a3d23 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,122 +1,170 @@ ## It is recommended to check out the user guide for more information. -### ResNet blocks -## These functions return a block to be used inside of a ResNet model. -## The individual arguments are explained in the documentation of the functions. -## Note that for these blocks to be used by the `_make_blocks` function, they must define -## a dispatch `expansion(::typeof(fn))` that returns the expansion factor of the block -## (i.e. the multiplicative factor by which the number of channels in the input is increased). -## The `_make_blocks` function will then call the `expansion` function to determine the -## expansion factor of each block and use this to construct the stages of the model. 
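To make the convention described in the comments above concrete: under the function-based API a block is just a callable plus an expansion-factor method. A minimal, hypothetical sketch, assuming Flux is loaded (`myblock` is not part of this patch):

    # A toy residual block: 3x3 conv + norm on the residual branch, `downsample` on the skip path.
    # Extra keyword arguments passed by `_make_blocks` are simply ignored via `kwargs...`.
    function myblock(inplanes, planes; stride = 1, downsample = identity, kwargs...)
        branch = Chain(Conv((3, 3), inplanes => planes; stride, pad = 1, bias = false),
                       BatchNorm(planes, relu))
        return Parallel(+, downsample, branch)
    end
    # The multiplicative factor between `planes` and the block's output channels;
    # to plug into `_make_blocks` this would extend Metalhead's own `expansion_factor`.
    expansion_factor(::typeof(myblock)) = 1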
+abstract type AbstractResNetBlock end -""" - basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, dilation = 1, first_dilation = dilation, - activation = relu, connection = addact\$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) +struct basicblock <: AbstractResNetBlock + inplanes::Integer + planes::Integer + reduction_factor::Integer +end +function basicblock(inplanes, planes, reduction_factor, base_width, cardinality) + @assert base_width == 64 "`base_width` must be 64 for `basicblock`" + @assert cardinality == 1 "`cardinality` must be 1 for `basicblock`" + return basicblock(inplanes, planes, reduction_factor) +end +expansion_factor(::basicblock) = 1 + +struct bottleneck <: AbstractResNetBlock + inplanes::Integer + planes::Integer + reduction_factor::Integer + base_width::Integer + cardinality::Integer +end +expansion_factor(::bottleneck) = 4 -Creates a basic ResNet block. +# Downsample layer using convolutions. +function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm) + return Chain(Conv((1, 1), inplanes => outplanes; stride, pad = SamePad(), bias = false), + norm_layer(outplanes)) +end -# Arguments +# Downsample layer using max pooling +function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm) + pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) + return Chain(pool, + Conv((1, 1), inplanes => outplanes; bias = false), + norm_layer(outplanes)) +end - - `inplanes`: number of input feature maps - - `planes`: number of feature maps for the block - - `stride`: the stride of the block - - `downsample`: the downsampling function to use - - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first - convolution. - - `dilation`: the dilation of the second convolution. - - `first_dilation`: the dilation of the first convolution. - - `activation`: the activation function to use. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - - `norm_layer`: the normalization layer to use. - - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. - - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the - attention function. 
-""" -function basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, dilation = 1, first_dilation = dilation, - activation = relu, connection = addact$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) - expansion = expansion_factor(basicblock) - first_planes = planes ÷ reduction_factor - outplanes = planes * expansion - conv_bn1 = Chain(Conv((3, 3), inplanes => first_planes; stride, pad = first_dilation, - dilation = first_dilation, bias = false), - norm_layer(first_planes)) - drop_block = drop_block - conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = dilation, - dilation = dilation, bias = false), - norm_layer(outplanes)) - attn_layer = attn_fn(outplanes; attn_args...) - return Parallel(connection, downsample, - Chain(conv_bn1, drop_block, activation, conv_bn2, attn_layer, - drop_path)) +# Downsample layer which is an identity projection. Uses max pooling +# when the output size is more than the input size. +function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) + if outplanes > inplanes + return Chain(MaxPool((1, 1); stride = 2), + y -> cat_channels(y, + zeros(eltype(y), + size(y, 1), + size(y, 2), + outplanes - inplanes, size(y, 4)))) + else + return identity + end end -expansion_factor(::typeof(basicblock)) = 1 -""" - bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, connection = addact\$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) +function downsample_block(downsample_fns, inplanes, planes, expansion; stride = 1, + norm_layer = BatchNorm) + down_fn1, down_fn2 = downsample_fns + if stride != 1 || inplanes != planes * expansion + return down_fn1(inplanes, planes * expansion; stride, norm_layer) + else + return down_fn2(inplanes, planes * expansion; stride, norm_layer) + end +end + +# Shortcut configurations for the ResNet models +const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) -Creates a bottleneck ResNet block. +# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is +# specified as a `Vector` of `Symbol`s. This is used to make the downsample +# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is +# already an `NTuple{2}` of functions, it is returned unchanged. +function _make_downsample_fns(vec::Vector{<:Symbol}, layers) + downs = [] + for i in vec + @assert i in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + push!(downs, shortcut_dict[i]) + end + return downs +end +function _make_downsample_fns(sym::Symbol, layers) + @assert sym in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + return collect(shortcut_dict[sym] for _ in 1:length(layers)) +end +_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec +_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) -# Arguments +# Stride for each block in the ResNet model +function get_stride(::AbstractResNetBlock, idxs::NTuple{2, Integer}) + return (idxs[1] == 1 || idxs[1] == 1) ? 
2 : 1 +end - - `inplanes`: number of input feature maps - - `planes`: number of feature maps for the block - - `stride`: the stride of the block - - `downsample`: the downsampling function to use - - `cardinality`: the number of groups in the convolution. - - `base_width`: the number of output feature maps for each convolutional group. - - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first - convolution. - - `first_dilation`: the dilation of the 3x3 convolution. - - `activation`: the activation function to use. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - - `norm_layer`: the normalization layer to use. - - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. - - `attn_args`: a NamedTuple that contains none, some or all of the arguments to be passed to the - attention function. -""" -function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, connection = addact$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity, attn_args::NamedTuple = NamedTuple()) - expansion = expansion_factor(bottleneck) - width = floor(Int, planes * (base_width / 64)) * cardinality - first_planes = width ÷ reduction_factor - outplanes = planes * expansion - conv_bn1 = Chain(Conv((1, 1), inplanes => first_planes; bias = false), - norm_layer(first_planes, activation)) - conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = first_dilation, - dilation = first_dilation, groups = cardinality, bias = false), +# returns `DropBlock`s for each stage of the ResNet +function _drop_blocks(drop_block_rate::AbstractFloat) + return [ + identity, identity, + DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00) + ] +end + +function _make_layers(block::basicblock, norm_layer, stride) + first_planes = block.planes ÷ block.reduction_factor + outplanes = block.planes * expansion_factor(block) + conv_bn1 = Chain(Conv((3, 3), block.inplanes => first_planes; stride, pad = 1, bias = false), + norm_layer(first_planes)) + conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = 1, bias = false), + norm_layer(outplanes)) + layers = [] + push!(layers, conv_bn1, conv_bn2) + return layers +end + +function _make_layers(block::bottleneck, norm_layer, stride) + width = fld(block.planes * block.base_width, 64) * block.cardinality + first_planes = width ÷ block.reduction_factor + outplanes = block.planes * expansion_factor(block) + conv_bn1 = Chain(Conv((1, 1), block.inplanes => first_planes; bias = false), + norm_layer(first_planes)) + conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = 1, + groups = block.cardinality, bias = false), norm_layer(width)) conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) - attn_layer = attn_fn(outplanes; attn_args...) 
- return Parallel(connection, downsample, - Chain(conv_bn1, conv_bn2, drop_block, activation, conv_bn3, - attn_layer, drop_path)) + layers = [] + push!(layers, conv_bn1, conv_bn2, conv_bn3) + return layers +end + +function make_block(block::T, idxs::NTuple{2, Integer}; kwargs...) where {T <: AbstractResNetBlock} + stage_idx, block_idx = idxs + kwargs = Dict(kwargs) + stride = get(kwargs, :stride_fn, get_stride)(block, idxs) + expansion = expansion_factor(block) + norm_layer = get(kwargs, :norm_layer, BatchNorm) + layers = _make_layers(block, norm_layer, stride) + activation = get(kwargs, :activation, relu) + insert!(layers, 2, activation) + if T <: bottleneck + insert!(layers, 4, activation) + end + if haskey(kwargs, :drop_block_rate) + layer_idx = T <: basicblock ? 2 : 3 + dropblock = _drop_blocks(kwargs[:drop_block_rate])[stage_idx] + insert!(layers, layer_idx, dropblock) + end + if haskey(kwargs, :attn_fn) + attn_layer = kwargs[:attn_fn](block.planes) + push!(layers, attn_layer) + end + if haskey(kwargs, :drop_path_rate) + droppath = DropPath(kwargs[:droppath_rates][block_idx]) + push!(layers, droppath) + end + if haskey(kwargs, :downsample_fns) + downsample_tup = kwargs[:downsample_fns][stage_idx] + downsample = downsample_block(downsample_tup, block.inplanes, block.planes, expansion; stride) + connection = get(kwargs, :connection, addact)$activation + return Parallel(connection, downsample, Chain(layers...)) + else + return Chain(layers...) + end end -expansion_factor(::typeof(bottleneck)) = 4 """ resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, @@ -145,8 +193,8 @@ on how to use this function. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. """ -function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, - norm_layer = BatchNorm, activation = relu) +function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, + replace_stem_pool::Bool = false, norm_layer = BatchNorm, activation = relu) @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem @@ -181,272 +229,43 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = return Chain(conv1, bn1, stempool), inplanes end -### Downsampling layers -## These will almost never be used directly. They are used by the `_make_blocks` function to -## build the downsampling layers. In most cases, these defaults will not need to be changed. -## If you wish to write your own ResNet model using the `_make_blocks` function, you can use -## this function to build the downsampling layers. - -# Downsample layer using convolutions. -function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size - dilation = kernel_size[1] > 1 ? dilation : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, inplanes => outplanes; stride, pad, - dilation, bias = false), - norm_layer(outplanes)) -end - -# Downsample layer using max pooling -function downsample_pool(kernel_size, inplanes, outplanes; stride = 1, dilation = 1, - norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? 
SamePad() : 0 - pool = MeanPool((2, 2); stride = avg_stride, pad) - end - return Chain(pool, - Conv((1, 1), inplanes => outplanes; bias = false), - norm_layer(outplanes)) -end - -# Downsample layer which is an identity projection. Uses max pooling -# when the output size is more than the input size. -function downsample_identity(kernel_size, inplanes, outplanes; kwargs...) - if outplanes > inplanes - return Chain(MaxPool((1, 1); stride = 2), - y -> cat_channels(y, - zeros(eltype(y), - size(y, 1), - size(y, 2), - outplanes - inplanes, size(y, 4)))) - else - return identity - end -end - -""" - downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1), - stride = 1, dilation = 1, norm_layer = BatchNorm) - -Wrapper function that makes it easier to build a downsample block inside a ResNet model. -This function is almost never used directly or customised by the user. - -# Arguments - - - `downsample_fn`: The function to use for downsampling in skip connections. Recommended usage - is passing in either `downsample_conv` or `downsample_pool`. - - `inplanes`: The number of input feature maps. - - `planes`: The number of output feature maps. - - `expansion`: The expansion factor of the block. - - `kernel_size`: The size of the convolutional kernel. - - `stride`: The stride of the convolutional layer. - - `dilation`: The dilation of the convolutional layer. - - `norm_layer`: The normalisation layer to be used. -""" -function downsample_block(downsample_fns, inplanes, planes, expansion; - kernel_size = (1, 1), stride = 1, dilation = 1, - norm_layer = BatchNorm) - down_fn1, down_fn2 = downsample_fns - if stride != 1 || inplanes != planes * expansion - return down_fn1(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) - else - return down_fn2(kernel_size, inplanes, planes * expansion; - stride, dilation, norm_layer) - end -end - -# Shortcut configurations for the ResNet models -const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), - :B => (downsample_conv, downsample_identity), - :C => (downsample_conv, downsample_conv), - :D => (downsample_pool, downsample_identity)) - -# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is -# specified as a `Vector` of `Symbol`s. This is used to make the downsample -# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is -# already an `NTuple{2}` of functions, it is returned unchanged. -function _make_downsample_fns(vec::Vector{T}, layers) where {T} - @assert length(vec) == length(layers) - "The length of the downsample `Vector` must match the number of stages" - if T <: Symbol - downs = [] - for i in vec - @assert i in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - push!(downs, shortcut_dict[i]) - end - return downs - elseif T <: NTuple{2} - return vec - else - throw(ArgumentError("The shortcut list must be a `Vector` of `Symbol`s or `NTuple{2}`s")) - end -end - -function _make_downsample_fns(sym::Symbol, layers) - @assert sym in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - return collect(shortcut_dict[sym] for _ in 1:length(layers)) -end -_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) - # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. 
A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32, - downsample_fns::Vector, drop_rates::NamedTuple, - block_args::NamedTuple) - @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" - expansion = expansion_factor(block_fn) +function resnet_stages(block_type, channels, block_repeats, inplanes; kwargs...) stages = [] - net_block_idx = 1 - net_stride = 4 - dilation = prev_dilation = 1 - # Stochastic depth linear decay rule (DropPath) - dp_rates = LinRange{Float32}(0.0, get(drop_rates, :drop_path_rate, 0), - sum(block_repeats)) - # DropBlock rate - dbr = get(drop_rates, :drop_block_rate, 0) + kwargs = Dict(kwargs) + cardinality = get(kwargs, :cardinality, 1) + base_width = get(kwargs, :base_width, 64) + reduction_factor = get(kwargs, :reduction_factor, 1) ## Construct each stage - for (stage_idx, itr) in enumerate(zip(channels, block_repeats, _drop_blocks(dbr), - downsample_fns)) - # Number of planes in each stage, number of blocks in each stage, and the drop block rate - planes, num_blocks, drop_block, down_fns = itr - # Stride calculations for each stage - stride = stage_idx == 1 ? 1 : 2 - if net_stride >= output_stride - dilation *= stride - stride = 1 - else - net_stride *= stride - end - # Downsample block; either a (default) convolution-based block or a pooling-based block - downsample = downsample_block(down_fns, inplanes, planes, expansion; - stride, dilation) + for (stage_idx, (planes, num_blocks)) in enumerate(zip(channels, block_repeats)) ## Construct the blocks for each stage blocks = [] for block_idx in 1:num_blocks - # Different behaviour for the first block of each stage - downsample = block_idx == 1 ? downsample : identity - stride = block_idx == 1 ? stride : 1 - push!(blocks, - block_fn(inplanes, planes; stride, downsample, - first_dilation = prev_dilation, - drop_path = DropPath(dp_rates[block_idx]), drop_block, - block_args...)) - prev_dilation = dilation - inplanes = planes * expansion - net_block_idx += 1 + block_struct = block_type(inplanes, planes, reduction_factor, base_width, cardinality) + block = make_block(block_struct, (stage_idx, block_idx); kwargs...) + inplanes = planes * expansion_factor(block_struct) + push!(blocks, block) end push!(stages, Chain(blocks...)) end - return Chain(stages...) -end - -# returns `DropBlock`s for each stage of the ResNet -function _drop_blocks(drop_block_prob = 0.0) - return [ - identity, identity, - DropBlock(drop_block_prob, 5, 0.25), DropBlock(drop_block_prob, 3, 1.00), - ] + return Chain(stages...), inplanes end -""" - resnet(block_fn, layers, downsample_opt = :B; - inchannels = 3, nclasses = 1000, output_stride = 32, - stem = first(resnet_stem(; inchannels)), inplanes = 64, - block_args::NamedTuple = NamedTuple(), - drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), - use_conv = false)) - -This function creates the layers for many ResNet-like models. See the user guide for more -information. - -!!! note - - If you are an end-user trying to use ResNet-like models, you should consider [`ResNet`](#) - and similar higher-level functions instead. This version is significantly more customisable - at the cost of being more significantly more complicated. - -# Arguments - - - `block_fn`: The type of block to use inside the ResNet model. 
Must be either `:basicblock`, - which is the standard ResNet block, or `:bottleneck`, which is the ResNet block with a - bottleneck structure. See the [paper](https://arxiv.org/abs/1512.03385) for more details. - - - `layers`: A list of integers specifying the number of blocks in each stage. For example, - `[3, 4, 6, 3]` would mean that the network would have 4 stages, with 3, 4, 6 and 3 blocks in - each. - - `downsample_opt`: Downsampling options. This can be any one of the following: - - + A single `Symbol` specifying the downsample option to use for all stages. The default - is :B, which corresponds to a 1x1 convolution-based downsample for every stage except - the first, which uses an identity projection. The other options are `:A`, which uses - an identity projection for all stages, `:C`, which uses a convolution-based - downsample for all stages and `:D`, which uses a max-pooling-based downsample for every - stage except the first, which uses an identity projection. `:A`, `:B` and `:C` are - are described in the [paper](https://arxiv.org/abs/1512.03385), while `:D` is - described in the [Bag of Tricks](https://arxiv.org/abs/1812.01187) paper. - + A `Vector` of `Symbol`s specifying the downsample options to use for each stage. The - choices are the same as the single option above. The length of this `Vector` must be - the same as the length of `layers`. - + A `Vector` of `NTuple{2}`s specifying the downsample functions to use for each stage. - The functions have to be passed in directly here - see [`downsample_identity`](#), - [`downsample_conv`](#), and [`downsample_pool`](#). The first element of each tuple is - the downsample function to use for the first stage, and the second element is the - function to use for the rest of the stages. The length of this `Vector` must be the - same as the length of `layers`. - - `nclasses`: The number of output classes. - - `inchannels`: The number of input channels. - - `output_stride`: The total stride of the network i.e. the amount by which the input is - downsampled throughout the network. This is used to determine the output size from the - backbone of the network. Must be one of `[8, 16, 32]`. - - `stem`: A constructed ResNet stem, passed in to be used in the model. `inplanes` should be - set to the number of output channels from this stem. Metalhead provides an in-built - function for creating a stem (see [`resnet_stem`](#)) but you can also create your - own (although this is not usually necessary). - - `inplanes`: The number of output channels from the stem. - - `block_args`: A `NamedTuple` that may define none, some or all the arguments to be passed - to the block function. For more information regarding valid arguments, see - the documentation for the block functions ([`basicblock`](#), [`bottleneck`](#)). - - `drop_rates`: A `NamedTuple` that may define none, some or all of the following: - - + `dropout_rate`: The rate of dropout to be used in the classifier head. - + `drop_path_rate`: Stochastic depth implemented using [`DropPath`](#). - + `drop_block_rate`: `DropBlock` regularisation implemented using [`DropBlock`](#). - - `classifier_args`: A `NamedTuple` that **must** specify the following arguments: - - + `pool_layer`: The pooling layer to use in the classifier head. Pass this in with the - arguments to the layer defined. For example, if you want to use an adaptive mean pooling - layer, you would pass in `AdaptiveMeanPool((1, 1))`. 
- + `use_conv`: Whether to use a 1x1 convolutional layer in the classifier head instead of a - `Dense` layer. -""" -function resnet(block_fn, layers, downsample_opt = :B; - inchannels = 3, nclasses = 1000, output_stride = 32, - stem = first(resnet_stem(; inchannels)), inplanes = 64, - block_args::NamedTuple = NamedTuple(), - drop_rates::NamedTuple = (dropout_rate = 0.0, drop_path_rate = 0.0, - drop_block_rate = 0.0), - classifier_args::NamedTuple = (pool_layer = AdaptiveMeanPool((1, 1)), - use_conv = false)) +function resnet(block_fn, layers, downsample_opt = :B; inchannels::Integer = 3, + nclasses::Integer = 1000, stem = first(resnet_stem(; inchannels)), + inplanes::Integer = 64, kwargs...) + kwargs = Dict(kwargs) ## Feature Blocks channels = collect(64 * 2^i for i in range(0, length(layers))) downsample_fns = _make_downsample_fns(downsample_opt, layers) - stage_blocks = _make_blocks(block_fn, channels, layers, inplanes; - output_stride, downsample_fns, drop_rates, block_args) + stage_blocks, num_features = resnet_stages(block_fn, channels, layers, inplanes; downsample_fns, kwargs...) ## Classifier head - expansion = expansion_factor(block_fn) - num_features = 512 * expansion - pool_layer, use_conv = classifier_args + # num_features = 512 * expansion_factor(block_fn) + pool_layer = get(kwargs, :pool_layer, AdaptiveMeanPool((1, 1))) + use_conv = get(kwargs, :use_conv, false) # Pooling if pool_layer === identity @assert use_conv @@ -456,7 +275,7 @@ function resnet(block_fn, layers, downsample_opt = :B; global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer # Fully-connected layer fc = create_fc(num_features, nclasses; use_conv) - classifier = Chain(global_pool, Dropout(get(drop_rates, :dropout_rate, 0)), fc) + classifier = Chain(global_pool, Dropout(get(kwargs, :dropout_rate, 0)), fc) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index dc1de9464..cee2e4757 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -31,8 +31,7 @@ function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width)) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index bc57a08fc..1de2f2195 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -27,8 +27,7 @@ end function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; attn_fn = squeeze_excite)) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, attn_fn = _ -> squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNet", depth)) end @@ -71,8 +70,8 @@ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_widt inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. 
Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, - block_args = (; cardinality, base_width, attn_fn = squeeze_excite)) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width, + attn_fn = _ -> squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index e0b870fe9..4269411ef 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -29,7 +29,7 @@ include("conv.jl") export conv_bn, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection include("drop.jl") -export DropPath, DropBlock +export DropBlock, DropPath, droppath_rates include("selayers.jl") export squeeze_excite, effective_squeeze_excite diff --git a/src/layers/drop.jl b/src/layers/drop.jl index dc6cb3c54..981a05efc 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -146,3 +146,13 @@ equivalent to `identity`. on the CPU. """ DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity + +""" + droppath_rates(drop_path_rate::AbstractFloat = 0.0; depth) + +Returns the drop path rates for a given depth using the linear scaling rule +((reference)[https://arxiv.org/abs/1603.09382]) +""" +function droppath_rates(drop_path_rate::AbstractFloat = 0.0; depth) + return LinRange{Float32}(0.0, drop_path_rate, depth) +end \ No newline at end of file diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 48f1efd8c..ec74c79d6 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -59,7 +59,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, embedplanes = 512, drop_path_rate = 0.0, depth = 12, nclasses = 1000, kwargs...) npatches = prod(imsize .÷ patch_size) - dp_rates = LinRange{Float32}(0.0, drop_path_rate, depth) + dp_rates = droppath_rates(drop_path_rate; depth) layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], kwargs...) 
From 7ad362ba7e2335c05e529c389b07338165453e83 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 22 Jul 2022 08:38:18 +0530 Subject: [PATCH 48/64] Refactor of resnet core II Closures is the name of the game --- .github/workflows/CI.yml | 4 +- src/convnets/convnext.jl | 2 +- src/convnets/efficientnet.jl | 3 +- src/convnets/resnets/core.jl | 312 ++++++++++++++++++------------- src/convnets/resnets/seresnet.jl | 7 +- src/layers/drop.jl | 12 +- src/layers/selayers.jl | 16 +- src/other/mlpmixer.jl | 2 +- src/utilities.jl | 9 + 9 files changed, 211 insertions(+), 156 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c383351bc..e148bdff2 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -33,8 +33,8 @@ jobs: - 'r"/*/Inception/Inceptionv*"' - '["InceptionResNetv2", "Xception"]' - '"DenseNet"' - - '"ConvNeXt"' - - '"ConvMixer"' + - '["ConvNeXt", "ConvMixer"]' + # - '"ConvMixer"' - '"ViT"' - '"Other"' steps: diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 113a35142..383e5d128 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -51,7 +51,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 push!(downsample_layers, downsample_layer) end stages = [] - dp_rates = droppath_rates(drop_path_rate; depth = sum(depths)) + dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths)) cur = 0 for i in eachindex(depths) push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]]) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 02c5b6eb6..7236d5d44 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -79,8 +79,7 @@ const efficientnet_block_configs = [ # w: width scaling # d: depth scaling # r: image resolution -const efficientnet_global_configs = Dict(# (r, (w, d)) - :b0 => (224, (1.0, 1.0)), +const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)), :b1 => (240, (1.0, 1.1)), :b2 => (260, (1.1, 1.2)), :b3 => (300, (1.2, 1.4)), diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 4b76a3d23..284746a60 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,27 +1,105 @@ ## It is recommended to check out the user guide for more information. -abstract type AbstractResNetBlock end +""" + basicblock(inplanes, planes; stride = 1, downsample = identity, + reduction_factor = 1, dilation = 1, first_dilation = dilation, + activation = relu, connection = addact\$activation, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) -struct basicblock <: AbstractResNetBlock - inplanes::Integer - planes::Integer - reduction_factor::Integer -end -function basicblock(inplanes, planes, reduction_factor, base_width, cardinality) - @assert base_width == 64 "`base_width` must be 64 for `basicblock`" - @assert cardinality == 1 "`cardinality` must be 1 for `basicblock`" - return basicblock(inplanes, planes, reduction_factor) +Creates a basic ResNet block. + +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `downsample`: the downsampling function to use + - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first + convolution. + - `dilation`: the dilation of the second convolution. 
+ - `first_dilation`: the dilation of the first convolution. + - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. +""" +function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + expansion = expansion_factor(basicblock) + first_planes = planes ÷ reduction_factor + outplanes = planes * expansion + conv_bn1 = [Conv((3, 3), inplanes => first_planes; stride, pad = 1, bias = false), + norm_layer(first_planes)] + conv_bn2 = [Conv((3, 3), first_planes => outplanes; pad = 1, bias = false), + norm_layer(outplanes)] + layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), + drop_path] + filter!(x -> x !== identity, layers) + return Chain(layers...) end -expansion_factor(::basicblock) = 1 +expansion_factor(::typeof(basicblock)) = 1 + +""" + bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, + base_width = 64, reduction_factor = 1, first_dilation = 1, + activation = relu, connection = addact\$activation, + norm_layer = BatchNorm, drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + +Creates a bottleneck ResNet block. + +# Arguments -struct bottleneck <: AbstractResNetBlock - inplanes::Integer - planes::Integer - reduction_factor::Integer - base_width::Integer - cardinality::Integer + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `downsample`: the downsampling function to use + - `cardinality`: the number of groups in the convolution. + - `base_width`: the number of output feature maps for each convolutional group. + - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first + convolution. + - `first_dilation`: the dilation of the 3x3 convolution. + - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` + function and passed in. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. 
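A hedged usage sketch of the block builder documented above (the `bottleneck` method added just below returns only the residual branch as a `Chain`; the skip path and `connection` are assembled later in `resnet_stages`). The values are illustrative:

    # SE-ResNeXt-style residual branch: grouped 3x3 conv plus squeeze-excite attention.
    branch = bottleneck(256, 128; stride = 2, cardinality = 32, base_width = 4,
                        attn_fn = planes -> squeeze_excite(planes))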
+""" +function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, + reduction_factor = 1, activation = relu, norm_layer = BatchNorm, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + expansion = 4 + width = floor(Int, planes * (base_width / 64)) * cardinality + first_planes = width ÷ reduction_factor + outplanes = planes * expansion + conv_bn1 = [Conv((1, 1), inplanes => first_planes; bias = false), + norm_layer(first_planes, activation)] + conv_bn2 = [ + Conv((3, 3), first_planes => width; stride, pad = 1, + groups = cardinality, + bias = false), + norm_layer(width)] + conv_bn3 = [Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)] + layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., + attn_fn(outplanes), drop_path] + filter!(x -> x !== identity, layers) + return Chain(layers...) end -expansion_factor(::bottleneck) = 4 +expansion_factor(::typeof(bottleneck)) = 4 # Downsample layer using convolutions. function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, @@ -54,16 +132,6 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) end end -function downsample_block(downsample_fns, inplanes, planes, expansion; stride = 1, - norm_layer = BatchNorm) - down_fn1, down_fn2 = downsample_fns - if stride != 1 || inplanes != planes * expansion - return down_fn1(inplanes, planes * expansion; stride, norm_layer) - else - return down_fn2(inplanes, planes * expansion; stride, norm_layer) - end -end - # Shortcut configurations for the ResNet models const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :B => (downsample_conv, downsample_identity), @@ -77,9 +145,9 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), function _make_downsample_fns(vec::Vector{<:Symbol}, layers) downs = [] for i in vec - @assert i in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - push!(downs, shortcut_dict[i]) + @assert i in keys(shortcut_dict) + "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" + push!(downs, shortcut_dict[i]) end return downs end @@ -92,80 +160,17 @@ _make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec _make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) # Stride for each block in the ResNet model -function get_stride(::AbstractResNetBlock, idxs::NTuple{2, Integer}) - return (idxs[1] == 1 || idxs[1] == 1) ? 2 : 1 -end +get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 1 : 2 -# returns `DropBlock`s for each stage of the ResNet +# returns `DropBlock`s for each stage of the ResNet as in timm. 
+# TODO - add experimental options for DropBlock as part of the API function _drop_blocks(drop_block_rate::AbstractFloat) return [ identity, identity, - DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00) + DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00), ] end -function _make_layers(block::basicblock, norm_layer, stride) - first_planes = block.planes ÷ block.reduction_factor - outplanes = block.planes * expansion_factor(block) - conv_bn1 = Chain(Conv((3, 3), block.inplanes => first_planes; stride, pad = 1, bias = false), - norm_layer(first_planes)) - conv_bn2 = Chain(Conv((3, 3), first_planes => outplanes; pad = 1, bias = false), - norm_layer(outplanes)) - layers = [] - push!(layers, conv_bn1, conv_bn2) - return layers -end - -function _make_layers(block::bottleneck, norm_layer, stride) - width = fld(block.planes * block.base_width, 64) * block.cardinality - first_planes = width ÷ block.reduction_factor - outplanes = block.planes * expansion_factor(block) - conv_bn1 = Chain(Conv((1, 1), block.inplanes => first_planes; bias = false), - norm_layer(first_planes)) - conv_bn2 = Chain(Conv((3, 3), first_planes => width; stride, pad = 1, - groups = block.cardinality, bias = false), - norm_layer(width)) - conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) - layers = [] - push!(layers, conv_bn1, conv_bn2, conv_bn3) - return layers -end - -function make_block(block::T, idxs::NTuple{2, Integer}; kwargs...) where {T <: AbstractResNetBlock} - stage_idx, block_idx = idxs - kwargs = Dict(kwargs) - stride = get(kwargs, :stride_fn, get_stride)(block, idxs) - expansion = expansion_factor(block) - norm_layer = get(kwargs, :norm_layer, BatchNorm) - layers = _make_layers(block, norm_layer, stride) - activation = get(kwargs, :activation, relu) - insert!(layers, 2, activation) - if T <: bottleneck - insert!(layers, 4, activation) - end - if haskey(kwargs, :drop_block_rate) - layer_idx = T <: basicblock ? 2 : 3 - dropblock = _drop_blocks(kwargs[:drop_block_rate])[stage_idx] - insert!(layers, layer_idx, dropblock) - end - if haskey(kwargs, :attn_fn) - attn_layer = kwargs[:attn_fn](block.planes) - push!(layers, attn_layer) - end - if haskey(kwargs, :drop_path_rate) - droppath = DropPath(kwargs[:droppath_rates][block_idx]) - push!(layers, droppath) - end - if haskey(kwargs, :downsample_fns) - downsample_tup = kwargs[:downsample_fns][stage_idx] - downsample = downsample_block(downsample_tup, block.inplanes, block.planes, expansion; stride) - connection = get(kwargs, :connection, addact)$activation - return Parallel(connection, downsample, Chain(layers...)) - else - return Chain(layers...) - end -end - """ resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) @@ -188,13 +193,13 @@ on how to use this function. shows peformance improvements over the `:deep` stem in some cases. - `inchannels`: The number of channels in the input. - - `replace_stem_pool`: Whether to replace the default 3x3 max pooling layer with a + - `replace_pool`: Whether to replace the default 3x3 max pooling layer with a 3x3 convolution with stride 2 and a normalisation layer. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. 
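A quick, illustrative sketch of how the stem builder described above might be called (the keyword names follow this patch; the exact channel count is not asserted here):

    # Deep stem with the max pool swapped for a strided convolution (Bag-of-Tricks style).
    stem, inplanes = resnet_stem(; stem_type = :deep, replace_pool = true)
    # `stem` is a Chain; `inplanes` is its output channel count and is what the first
    # residual stage (and the `inplanes` keyword of `resnet`) expects.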
""" function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, - replace_stem_pool::Bool = false, norm_layer = BatchNorm, activation = relu) + replace_pool::Bool = false, norm_layer = BatchNorm, activation = relu) @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem @@ -219,51 +224,98 @@ function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, end bn1 = norm_layer(inplanes, activation) # Stem pooling - if replace_stem_pool - stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, - bias = false), - norm_layer(inplanes, activation)) - else - stempool = MaxPool((3, 3); stride = 2, pad = 1) - end + stempool = replace_pool ? + Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, bias = false), + norm_layer(inplanes, activation)) : + MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool), inplanes end +function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, attn_fn = planes -> identity, kargs...) + return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor, + activation, norm_layer, attn_fn) +end + +function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64, + reduction_factor = 1, activation = relu, norm_layer = BatchNorm, + attn_fn = planes -> identity, kargs...) + return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width, + reduction_factor, activation, norm_layer, + attn_fn) +end + +function template_builder(downsample_fn::Union{typeof(downsample_conv), + typeof(downsample_pool), + typeof(downsample_identity)}; + norm_layer = BatchNorm) + return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer) +end + +function configure_block(block_template, layers::Vector{Int}; expansion, + downsample_templates::Vector, inplanes::Integer = 64, + drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers)) + # closure over `idxs` + function get_layers(idxs::NTuple{2, Int}) + stage_idx, block_idx = idxs + planes = 64 * 2^(stage_idx - 1) + stride = get_stride(idxs) + downsample_fns = downsample_templates[stage_idx] + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_fns[1] : + downsample_fns[2] + schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = block_template(inplanes, planes; stride, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride) + # inplanes increases by expansion after each block + inplanes = (planes * expansion) + return ((block, downsample), inplanes) + end + return get_layers +end + # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function resnet_stages(block_type, channels, block_repeats, inplanes; kwargs...) +function resnet_stages(block_fn, block_repeats::Vector{Int}, inplanes::Integer; + downsample_vec::Vector, connection = addact, + activation = relu, kwargs...) 
+ outplanes = 0 + # Configure block template + block_template = template_builder(block_fn; kwargs...) + downsample_templates = map(x -> template_builder.(x), downsample_vec) + get_layers = configure_block(block_template, block_repeats; inplanes, + downsample_templates, + expansion = expansion_factor(block_fn), kwargs...) + # Construct each stage stages = [] - kwargs = Dict(kwargs) - cardinality = get(kwargs, :cardinality, 1) - base_width = get(kwargs, :base_width, 64) - reduction_factor = get(kwargs, :reduction_factor, 1) - ## Construct each stage - for (stage_idx, (planes, num_blocks)) in enumerate(zip(channels, block_repeats)) - ## Construct the blocks for each stage + for (stage_idx, (num_blocks)) in enumerate(block_repeats) + # Construct the blocks for each stage blocks = [] - for block_idx in 1:num_blocks - block_struct = block_type(inplanes, planes, reduction_factor, base_width, cardinality) - block = make_block(block_struct, (stage_idx, block_idx); kwargs...) - inplanes = planes * expansion_factor(block_struct) + for block_idx in range(1, num_blocks) + layers, outplanes = get_layers((stage_idx, block_idx)) + block = Parallel(connection$activation, layers...) push!(blocks, block) end push!(stages, Chain(blocks...)) end - return Chain(stages...), inplanes + return Chain(stages...), outplanes end -function resnet(block_fn, layers, downsample_opt = :B; inchannels::Integer = 3, - nclasses::Integer = 1000, stem = first(resnet_stem(; inchannels)), - inplanes::Integer = 64, kwargs...) - kwargs = Dict(kwargs) - ## Feature Blocks - channels = collect(64 * 2^i for i in range(0, length(layers))) - downsample_fns = _make_downsample_fns(downsample_opt, layers) - stage_blocks, num_features = resnet_stages(block_fn, channels, layers, inplanes; downsample_fns, kwargs...) - ## Classifier head - # num_features = 512 * expansion_factor(block_fn) +function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; + inchannels::Integer = 3, nclasses::Integer = 1000, + stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64, + kwargs...) + # Feature Blocks + downsample_vec = _make_downsample_fns(downsample_opt, layers) + stage_blocks, num_features = resnet_stages(block_fn, layers, inplanes; downsample_vec, + kwargs...) + # Classifier head pool_layer = get(kwargs, :pool_layer, AdaptiveMeanPool((1, 1))) use_conv = get(kwargs, :use_conv, false) # Pooling diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 1de2f2195..26bfbb6c1 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -27,7 +27,8 @@ end function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, attn_fn = _ -> squeeze_excite) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, + attn_fn = planes -> squeeze_excite(planes)) if pretrain loadpretrain!(layers, string("SEResNet", depth)) end @@ -70,8 +71,8 @@ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_widt inchannels = 3, nclasses = 1000) @assert depth in [50, 101, 152] "Invalid depth. 
Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width, - attn_fn = _ -> squeeze_excite) + layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width, + attn_fn = planes -> squeeze_excite(planes)) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 981a05efc..8ad47af4a 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -25,6 +25,8 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. """ +# TODO add experimental `DropBlock` options from timm such as gaussian noise and +# more precise `DropBlock` to deal with edges. function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale) where {T} H, W, _, _ = size(x) @@ -146,13 +148,3 @@ equivalent to `identity`. on the CPU. """ DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity - -""" - droppath_rates(drop_path_rate::AbstractFloat = 0.0; depth) - -Returns the drop path rates for a given depth using the linear scaling rule -((reference)[https://arxiv.org/abs/1603.09382]) -""" -function droppath_rates(drop_path_rate::AbstractFloat = 0.0; depth) - return LinRange{Float32}(0.0, drop_path_rate, depth) -end \ No newline at end of file diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 6d86947c9..8572adb9f 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -19,13 +19,15 @@ function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, activation = relu, gate_activation = sigmoid, norm_layer = planes -> identity, rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) - return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes => rd_planes), - norm_layer(rd_planes), - activation, - Conv((1, 1), rd_planes => inplanes), - norm_layer(inplanes), - gate_activation), .*) + layers = [AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes => rd_planes), + norm_layer(rd_planes), + activation, + Conv((1, 1), rd_planes => inplanes), + norm_layer(inplanes), + gate_activation] + filter!(x -> x !== identity, layers) + return SkipConnection(Chain(layers...), .*) end """ diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index ec74c79d6..ac66be28b 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -59,7 +59,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, embedplanes = 512, drop_path_rate = 0.0, depth = 12, nclasses = 1000, kwargs...) npatches = prod(imsize .÷ patch_size) - dp_rates = droppath_rates(drop_path_rate; depth) + dp_rates = linear_scheduler(drop_path_rate; depth) layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], kwargs...) diff --git a/src/utilities.jl b/src/utilities.jl index 938c598f0..b420efd0b 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -61,3 +61,12 @@ function _maybe_big_show(io, model) show(io, model) end end + +""" + linear_scheduler(drop_path_rate = 0.0; start_value = 0.0, depth) + +Returns the dropout rates for a given depth using the linear scaling rule. 
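For instance, the scheduler might be queried as below; the values are purely illustrative and assume the helper is in scope (e.g. via `Metalhead.Layers`):

```julia
# per-block drop rates for an 8-block model, scaled linearly from 0.0 up to 0.1
rates = linear_scheduler(0.1; depth = 8)   # LinRange{Float32}(0.0f0, 0.1f0, 8)
```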
+""" +function linear_scheduler(drop_rate = 0.0; depth, start_value = 0.0) + return LinRange{Float32}(start_value, drop_rate, depth) +end From 13ed5ac766d49cbbe3747278aa8976a441e9b341 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 22 Jul 2022 14:03:58 +0530 Subject: [PATCH 49/64] Allow `prenorm` Simplify `conv_bn` to `conv_norm` and use it --- src/convnets/convmixer.jl | 12 +- src/convnets/densenet.jl | 12 +- src/convnets/efficientnet.jl | 8 +- src/convnets/inception.jl | 268 +++++++++++++++++------------------ src/convnets/mobilenet.jl | 15 +- src/convnets/resnets/core.jl | 79 ++++++----- src/convnets/vgg.jl | 2 +- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 87 +++++++----- src/layers/normalise.jl | 1 - 10 files changed, 251 insertions(+), 235 deletions(-) diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index d36f1a8d5..f3566e278 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -17,12 +17,12 @@ Creates a ConvMixer model. """ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000) - stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, - stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation; - preact = true, groups = planes, - pad = SamePad())), +), - conv_bn((1, 1), planes, planes, activation; preact = true)...) + stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, + stride = patch_size[1]) + blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; + preact = true, groups = planes, + pad = SamePad())), +), + conv_norm((1, 1), planes, planes, activation; preact = true)...) 
for _ in 1:depth] head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses)) return Chain(Chain(stem..., Chain(blocks)), head) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 9da4e08b2..878e3fc8d 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -12,10 +12,10 @@ Create a Densenet bottleneck layer """ function dense_bottleneck(inplanes, outplanes) inner_channels = 4 * outplanes - return SkipConnection(Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, - rev = true)..., - conv_bn((3, 3), inner_channels, outplanes; pad = 1, - bias = false, rev = true)...), + return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, + prenorm = true)..., + conv_norm((3, 3), inner_channels, outplanes; pad = 1, + bias = false, prenorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes, outplanes) - return Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, prenorm = true)..., MeanPool((2, 2))) end @@ -70,7 +70,7 @@ Create a DenseNet model """ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000) layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)) + append!(layers, conv_norm((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)) push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 7236d5d44..aeacc8092 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -29,8 +29,8 @@ function efficientnet(scalings, block_config; scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) out_channels = _round_channels(scalew(32), 8) - stem = conv_bn((3, 3), inchannels, out_channels, swish; - bias = false, stride = 2, pad = SamePad()) + stem = conv_norm((3, 3), inchannels, out_channels, swish; + bias = false, stride = 2, pad = SamePad()) blocks = [] for (n, k, s, e, i, o) in block_config @@ -50,8 +50,8 @@ function efficientnet(scalings, block_config; blocks = Chain(blocks...) head_out_channels = _round_channels(max_width, 8) - head = conv_bn((1, 1), out_channels, head_out_channels, swish; - bias = false, pad = SamePad()) + head = conv_norm((1, 1), out_channels, head_out_channels, swish; + bias = false, pad = SamePad()) top = Dense(head_out_channels, nclasses) diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index ba9c935f6..4fd43f26c 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -12,14 +12,14 @@ Create an Inception-v3 style-A module - `pool_proj`: the number of output feature maps for the pooling projection """ function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)..., - conv_bn((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) + branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) 
branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), inplanes, pool_proj)...) + conv_norm((1, 1), inplanes, pool_proj)...) return Parallel(cat_channels, branch1x1, branch5x5, branch3x3, branch_pool) end @@ -35,10 +35,10 @@ Create an Inception-v3 style-B module - `inplanes`: number of input feature maps """ function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; stride = 2)...) + branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; stride = 2)...) branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3_1, branch3x3_2, branch_pool) @@ -57,17 +57,17 @@ Create an Inception-v3 style-C module - `n`: the "grid size" (kernel size) for the convolution layers """ function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)..., - conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...) - branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)..., - conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_bn((1, n), inner_planes, 192; pad = (0, 3))...) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) + branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) + branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), inplanes, 192)...) + conv_norm((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, branch7x7_1, branch7x7_2, branch_pool) end @@ -83,12 +83,12 @@ Create an Inception-v3 style-D module - `inplanes`: number of input feature maps """ function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)..., - conv_bn((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)..., - conv_bn((1, 7), 192, 192; pad = (0, 3))..., - conv_bn((7, 1), 192, 192; pad = (3, 0))..., - conv_bn((3, 3), 192, 192; stride = 2)...) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((1, 7), 192, 192; pad = (0, 3))..., + conv_norm((7, 1), 192, 192; pad = (3, 0))..., + conv_norm((3, 3), 192, 192; stride = 2)...) 
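Each of these Inception modules merges its branches with `Parallel(cat_channels, ...)`, i.e. the branch outputs are concatenated along the channel dimension. A minimal self-contained sketch of that pattern, using a local stand-in for Metalhead's `cat_channels`:

```julia
using Flux

# local stand-in for Metalhead's cat_channels: concatenate along the channel dim
cat_channels(xs...) = cat(xs...; dims = 3)

branch1 = Chain(Conv((1, 1), 192 => 64, relu))
branch2 = Chain(Conv((3, 3), 192 => 96, relu; pad = 1))
m = Parallel(cat_channels, branch1, branch2)

x = rand(Float32, 32, 32, 192, 1)
size(m(x))   # (32, 32, 160, 1): 64 + 96 channels
```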
branch_pool = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch3x3, branch7x7x3, branch_pool) @@ -105,16 +105,16 @@ Create an Inception-v3 style-E module - `inplanes`: number of input feature maps """ function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)..., - conv_bn((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) + branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) + branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., + conv_norm((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), inplanes, 192)...) + conv_norm((1, 1), inplanes, 192)...) return Parallel(cat_channels, branch1x1, Chain(branch3x3_1, @@ -136,12 +136,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - `nclasses`: the number of output classes """ function inceptionv3(; nclasses = 1000) - layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., + layer = Chain(Chain(conv_norm((3, 3), 3, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_bn((1, 1), 64, 80)..., - conv_bn((3, 3), 80, 192)..., + conv_norm((1, 1), 64, 80)..., + conv_norm((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), inceptionv3_a(192, 32), inceptionv3_a(256, 64), @@ -200,83 +200,83 @@ classifier(m::Inceptionv3) = m.layers[2] function mixed_3a() return Parallel(cat_channels, MaxPool((3, 3); stride = 2), - Chain(conv_bn((3, 3), 64, 96; stride = 2)...)) + Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) end function mixed_4a() return Parallel(cat_channels, - Chain(conv_bn((1, 1), 160, 64)..., - conv_bn((3, 3), 64, 96)...), - Chain(conv_bn((1, 1), 160, 64)..., - conv_bn((1, 7), 64, 64; pad = (0, 3))..., - conv_bn((7, 1), 64, 64; pad = (3, 0))..., - conv_bn((3, 3), 64, 96)...)) + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((3, 3), 64, 96)...), + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((1, 7), 64, 64; pad = (0, 3))..., + conv_norm((7, 1), 64, 64; pad = (3, 0))..., + conv_norm((3, 3), 64, 96)...)) end function mixed_5a() return Parallel(cat_channels, - Chain(conv_bn((3, 3), 192, 192; stride = 2)...), + Chain(conv_norm((3, 3), 192, 192; stride = 2)...), MaxPool((3, 3); stride = 2)) end function inceptionv4_a() - branch1 = Chain(conv_bn((1, 1), 384, 96)...) - branch2 = Chain(conv_bn((1, 1), 384, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_bn((1, 1), 384, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_bn((1, 1), 384, 96)...) + branch1 = Chain(conv_norm((1, 1), 384, 96)...) + branch2 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)...) 
+ branch3 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function reduction_a() - branch1 = Chain(conv_bn((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 384, 192)..., - conv_bn((3, 3), 192, 224; pad = 1)..., - conv_bn((3, 3), 224, 256; stride = 2)...) + branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 384, 192)..., + conv_norm((3, 3), 192, 224; pad = 1)..., + conv_norm((3, 3), 224, 256; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function inceptionv4_b() - branch1 = Chain(conv_bn((1, 1), 1024, 384)...) - branch2 = Chain(conv_bn((1, 1), 1024, 192)..., - conv_bn((1, 7), 192, 224; pad = (0, 3))..., - conv_bn((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_bn((1, 1), 1024, 192)..., - conv_bn((7, 1), 192, 192; pad = (0, 3))..., - conv_bn((1, 7), 192, 224; pad = (3, 0))..., - conv_bn((7, 1), 224, 224; pad = (0, 3))..., - conv_bn((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_bn((1, 1), 1024, 128)...) + branch1 = Chain(conv_norm((1, 1), 1024, 384)...) + branch2 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((1, 7), 192, 224; pad = (0, 3))..., + conv_norm((7, 1), 224, 256; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((7, 1), 192, 192; pad = (0, 3))..., + conv_norm((1, 7), 192, 224; pad = (3, 0))..., + conv_norm((7, 1), 224, 224; pad = (0, 3))..., + conv_norm((1, 7), 224, 256; pad = (3, 0))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function reduction_b() - branch1 = Chain(conv_bn((1, 1), 1024, 192)..., - conv_bn((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 1024, 256)..., - conv_bn((1, 7), 256, 256; pad = (0, 3))..., - conv_bn((7, 1), 256, 320; pad = (3, 0))..., - conv_bn((3, 3), 320, 320; stride = 2)...) + branch1 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1024, 256)..., + conv_norm((1, 7), 256, 256; pad = (0, 3))..., + conv_norm((7, 1), 256, 320; pad = (3, 0))..., + conv_norm((3, 3), 320, 320; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function inceptionv4_c() - branch1 = Chain(conv_bn((1, 1), 1536, 256)...) - branch2 = Chain(conv_bn((1, 1), 1536, 384)..., + branch1 = Chain(conv_norm((1, 1), 1536, 256)...) 
+ branch2 = Chain(conv_norm((1, 1), 1536, 384)..., Parallel(cat_channels, - Chain(conv_bn((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_bn((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_bn((1, 1), 1536, 384)..., - conv_bn((3, 1), 384, 448; pad = (1, 0))..., - conv_bn((1, 3), 448, 512; pad = (0, 1))..., + Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) + branch3 = Chain(conv_norm((1, 1), 1536, 384)..., + conv_norm((3, 1), 384, 448; pad = (1, 0))..., + conv_norm((1, 3), 448, 512; pad = (0, 1))..., Parallel(cat_channels, - Chain(conv_bn((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_bn((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_bn((1, 1), 1536, 256)...) + Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end @@ -293,9 +293,9 @@ Create an Inceptionv4 model. - `nclasses`: the number of output classes. """ function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., mixed_3a(), mixed_4a(), mixed_5a(), @@ -360,66 +360,66 @@ classifier(m::Inceptionv4) = m.layers[2] ## Inception-ResNetv2 function mixed_5b() - branch1 = Chain(conv_bn((1, 1), 192, 96)...) - branch2 = Chain(conv_bn((1, 1), 192, 48)..., - conv_bn((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_bn((1, 1), 192, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) + branch1 = Chain(conv_norm((1, 1), 192, 96)...) + branch2 = Chain(conv_norm((1, 1), 192, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(conv_norm((1, 1), 192, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), 192, 64)...) + conv_norm((1, 1), 192, 64)...) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block35(scale = 1.0f0) - branch1 = Chain(conv_bn((1, 1), 320, 32)...) - branch2 = Chain(conv_bn((1, 1), 320, 32)..., - conv_bn((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_bn((1, 1), 320, 32)..., - conv_bn((3, 3), 32, 48; pad = 1)..., - conv_bn((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_bn((1, 1), 128, 320)...) + branch1 = Chain(conv_norm((1, 1), 320, 32)...) + branch2 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 48; pad = 1)..., + conv_norm((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(conv_norm((1, 1), 128, 320)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), branch4, inputscale(scale; activation = relu)), +) end function mixed_6a() - branch1 = Chain(conv_bn((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 320, 256)..., - conv_bn((3, 3), 256, 256; pad = 1)..., - conv_bn((3, 3), 256, 384; stride = 2)...) + branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) 
+ branch2 = Chain(conv_norm((1, 1), 320, 256)..., + conv_norm((3, 3), 256, 256; pad = 1)..., + conv_norm((3, 3), 256, 384; stride = 2)...) branch3 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3) end function block17(scale = 1.0f0) - branch1 = Chain(conv_bn((1, 1), 1088, 192)...) - branch2 = Chain(conv_bn((1, 1), 1088, 128)..., - conv_bn((1, 7), 128, 160; pad = (0, 3))..., - conv_bn((7, 1), 160, 192; pad = (3, 0))...) - branch3 = Chain(conv_bn((1, 1), 384, 1088)...) + branch1 = Chain(conv_norm((1, 1), 1088, 192)...) + branch2 = Chain(conv_norm((1, 1), 1088, 128)..., + conv_norm((1, 7), 128, 160; pad = (0, 3))..., + conv_norm((7, 1), 160, 192; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 384, 1088)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation = relu)), +) end function mixed_7a() - branch1 = Chain(conv_bn((1, 1), 1088, 256)..., - conv_bn((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 1088, 256)..., - conv_bn((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_bn((1, 1), 1088, 256)..., - conv_bn((3, 3), 256, 288; pad = 1)..., - conv_bn((3, 3), 288, 320; stride = 2)...) + branch1 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; pad = 1)..., + conv_norm((3, 3), 288, 320; stride = 2)...) branch4 = MaxPool((3, 3); stride = 2) return Parallel(cat_channels, branch1, branch2, branch3, branch4) end function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_bn((1, 1), 2080, 192)...) - branch2 = Chain(conv_bn((1, 1), 2080, 192)..., - conv_bn((1, 3), 192, 224; pad = (0, 1))..., - conv_bn((3, 1), 224, 256; pad = (1, 0))...) - branch3 = Chain(conv_bn((1, 1), 448, 2080)...) + branch1 = Chain(conv_norm((1, 1), 2080, 192)...) + branch2 = Chain(conv_norm((1, 1), 2080, 192)..., + conv_norm((1, 3), 192, 224; pad = (0, 1))..., + conv_norm((3, 1), 224, 256; pad = (1, 0))...) + branch3 = Chain(conv_norm((1, 1), 448, 2080)...) return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), branch3, inputscale(scale; activation)), +) end @@ -437,12 +437,12 @@ Creates an InceptionResNetv2 model. - `nclasses`: the number of output classes. """ function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., MaxPool((3, 3); stride = 2), - conv_bn((3, 3), 64, 80)..., - conv_bn((3, 3), 80, 192)..., + conv_norm((3, 3), 64, 80)..., + conv_norm((3, 3), 80, 192)..., MaxPool((3, 3); stride = 2), mixed_5b(), [block35(0.17f0) for _ in 1:10]..., @@ -451,7 +451,7 @@ function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000 mixed_7a(), [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), - conv_bn((1, 1), 2080, 1536)...) + conv_norm((1, 1), 2080, 1536)...) 
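Throughout this rename, the result of `conv_norm` keeps being splatted into `Chain`, because it returns a `Vector` of layers (the convolution followed by the norm layer) rather than a `Chain`. A rough, self-contained sketch of the idiom, with a simplified stand-in for the real `conv_norm` defined later in this patch:

```julia
using Flux

# simplified stand-in: a convolution + norm pair returned as a Vector of layers
conv_norm_sketch(k, inplanes, outplanes, activation = relu; kwargs...) =
    [Conv(k, inplanes => outplanes; bias = false, kwargs...),
     BatchNorm(outplanes, activation)]

# splatting flattens both pairs into one four-layer Chain
branch = Chain(conv_norm_sketch((1, 1), 192, 64)...,
               conv_norm_sketch((3, 3), 64, 96; pad = 1)...)
```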
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(1536, nclasses)) return Chain(body, head) @@ -516,8 +516,8 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, grow_at_start = true) if outchannels != inchannels || stride != 1 - skip = conv_bn((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, + bias = false) else skip = [identity] end @@ -554,8 +554,8 @@ Creates an Xception model. - `nclasses`: the number of output classes. """ function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_bn((3, 3), 32, 64; bias = false)..., + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., + conv_norm((3, 3), 32, 64; bias = false)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), xception_block(128, 256, 2; stride = 2), xception_block(256, 728, 2; stride = 2), diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 15dc037e8..06274c94e 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -37,8 +37,9 @@ function mobilenetv1(width_mult, config; layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1, bias = false) : - conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1, - bias = false) + conv_norm((3, 3), inchannels, outch, activation; stride = stride, + pad = 1, + bias = false) append!(layers, layer) inchannels = outch end @@ -131,7 +132,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla # building first layer inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) layers = [] - append!(layers, conv_bn((3, 3), inchannels, inplanes; pad = 1, stride = 2)) + append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) @@ -147,7 +148,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla _round_channels(max_width * width_mult, width_mult == 0.1 ? 
4 : 8) : max_width return Chain(Chain(Chain(layers), - conv_bn((1, 1), inplanes, outplanes, relu6; bias = false)...), + conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)...), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses))) end @@ -235,8 +236,8 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla inplanes = _round_channels(16 * width_mult, 8) layers = [] append!(layers, - conv_bn((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, - bias = false)) + conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, + bias = false)) explanes = 0 # building inverted residual blocks for (k, t, c, r, a, s) in configs @@ -256,7 +257,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla Dropout(0.2), Dense(output_channel, nclasses)) return Chain(Chain(Chain(layers), - conv_bn((1, 1), inplanes, explanes, hardswish; bias = false)...), + conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)...), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 284746a60..6fdc85664 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -31,15 +31,16 @@ Creates a basic ResNet block. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. """ function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, + norm_layer = BatchNorm, prenorm = false, + drop_block = identity, drop_path = identity, attn_fn = planes -> identity) expansion = expansion_factor(basicblock) first_planes = planes ÷ reduction_factor outplanes = planes * expansion - conv_bn1 = [Conv((3, 3), inplanes => first_planes; stride, pad = 1, bias = false), - norm_layer(first_planes)] - conv_bn2 = [Conv((3, 3), first_planes => outplanes; pad = 1, bias = false), - norm_layer(outplanes)] + conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm, + stride, pad = 1, bias = false) + conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm, + pad = 1, bias = false) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] filter!(x -> x !== identity, layers) @@ -79,21 +80,20 @@ Creates a bottleneck ResNet block. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. 
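For instance, a ResNeXt-style block could be configured as below; this is only a sketch assuming the keyword defaults listed above and that the internal function is in scope (e.g. `Metalhead.bottleneck`):

```julia
# illustrative: a bottleneck block with grouped 3x3 convolutions, as used by ResNeXt
block = bottleneck(256, 128; stride = 2, cardinality = 32, base_width = 4)
```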
""" function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, - reduction_factor = 1, activation = relu, norm_layer = BatchNorm, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = 4 + expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * expansion - conv_bn1 = [Conv((1, 1), inplanes => first_planes; bias = false), - norm_layer(first_planes, activation)] - conv_bn2 = [ - Conv((3, 3), first_planes => width; stride, pad = 1, - groups = cardinality, - bias = false), - norm_layer(width)] - conv_bn3 = [Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)] + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm, + bias = false) + conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm, + stride, pad = 1, groups = cardinality, bias = false) + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, prenorm, + bias = false) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] filter!(x -> x !== identity, layers) @@ -103,18 +103,18 @@ expansion_factor(::typeof(bottleneck)) = 4 # Downsample layer using convolutions. function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm) - return Chain(Conv((1, 1), inplanes => outplanes; stride, pad = SamePad(), bias = false), - norm_layer(outplanes)) + norm_layer = BatchNorm, prenorm = false) + return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, prenorm, + pad = SamePad(), stride, bias = false)...) end # Downsample layer using max pooling function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm) + norm_layer = BatchNorm, prenorm = false) pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - Conv((1, 1), inplanes => outplanes; bias = false), - norm_layer(outplanes)) + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, prenorm, + bias = false)...) end # Downsample layer which is an identity projection. Uses max pooling @@ -198,8 +198,9 @@ on how to use this function. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. 
""" -function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, - replace_pool::Bool = false, norm_layer = BatchNorm, activation = relu) +function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, + replace_pool::Bool = false, norm_layer = BatchNorm, prenorm = false, + activation = relu) @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem @@ -212,12 +213,10 @@ function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, elseif stem_type == :deep_tiered stem_channels = (3 * (stem_width ÷ 4), stem_width) end - conv1 = Chain(Conv((3, 3), inchannels => stem_channels[1]; stride = 2, pad = 1, - bias = false), - norm_layer(stem_channels[1], activation), - Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, - bias = false), - norm_layer(stem_channels[2], activation), + conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; + norm_layer, prenorm, stride = 2, pad = 1, bias = false)..., + conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; + norm_layer, pad = 1, bias = false)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) @@ -225,31 +224,34 @@ function resnet_stem(; stem_type::Symbol = :default, inchannels::Integer = 3, bn1 = norm_layer(inplanes, activation) # Stem pooling stempool = replace_pool ? - Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, bias = false), - norm_layer(inplanes, activation)) : + Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, + prenorm, + stride = 2, pad = 1, bias = false)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool), inplanes end function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, attn_fn = planes -> identity, kargs...) + norm_layer = BatchNorm, prenorm = false, + attn_fn = planes -> identity, kargs...) return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor, - activation, norm_layer, attn_fn) + activation, norm_layer, prenorm, attn_fn) end function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64, - reduction_factor = 1, activation = relu, norm_layer = BatchNorm, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, attn_fn = planes -> identity, kargs...) return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width, - reduction_factor, activation, norm_layer, - attn_fn) + reduction_factor, activation, + norm_layer, prenorm, attn_fn) end function template_builder(downsample_fn::Union{typeof(downsample_conv), typeof(downsample_pool), typeof(downsample_identity)}; - norm_layer = BatchNorm) - return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer) + norm_layer = BatchNorm, prenorm = false) + return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm) end function configure_block(block_template, layers::Vector{Int}; expansion, @@ -266,6 +268,7 @@ function configure_block(block_template, layers::Vector{Int}; expansion, downsample_fn = (stride != 1 || inplanes != planes * expansion) ? 
downsample_fns[1] : downsample_fns[2] + # DropBlock, DropPath both take in rates based on a linear scaling schedule schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 957a0a483..c8a2e6344 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -17,7 +17,7 @@ function vgg_block(ifilters, ofilters, depth, batchnorm) layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_bn(k, ifilters, ofilters; pad = p, bias = false)) + append!(layers, conv_norm(k, ifilters, ofilters; pad = p, bias = false)) else push!(layers, Conv(k, ifilters => ofilters, relu; pad = p)) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 4269411ef..bb4a1ef2b 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -26,7 +26,7 @@ include("normalise.jl") export prenorm, ChannelLayerNorm include("conv.jl") -export conv_bn, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection +export conv_norm, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection include("drop.jl") export DropBlock, DropPath, droppath_rates diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7605a6cd1..b28507d6c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,35 +1,39 @@ """ - conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, - groups = 1, [bias, weight, init]) + conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; + norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true, + stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init]) Create a convolution + batch normalization pair with activation. # Arguments - - `kernelsize`: size of the convolution kernel (tuple) + - `kernel_size`: size of the convolution kernel (tuple) - `inplanes`: number of input feature maps - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - - `rev`: set to `true` to place the batch norm before the convolution + - `norm_layer`: the normalization layer used + - `prenorm`: set to `true` to place the batch norm before the convolution - `preact`: set to `true` to place the activation function before the batch norm - (only compatible with `rev = false`) + (only compatible with `prenorm = false`) - `use_bn`: set to `false` to disable batch normalization - (only compatible with `rev = false` and `preact = false`) + (only compatible with `prenorm = false` and `preact = false`) - `stride`: stride of the convolution kernel - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, kwargs...) +function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; + norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true, + kwargs...) if !use_bn - (preact || rev) ? 
throw("preact only supported with `use_bn = true`") : - return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)] + if (preact || prenorm) + throw(ArgumentError("`preact` only supported with `use_bn = true`")) + else + return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] + end end - layers = [] - if rev + if prenorm activations = (conv = activation, bn = identity) bnplanes = inplanes else @@ -37,50 +41,59 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; bnplanes = outplanes end if preact - rev ? throw(ArgumentError("preact and rev cannot be set at the same time")) : - activations = (conv = activation, bn = identity) + if prenorm + throw(ArgumentError("`preact` and `prenorm` cannot be set at the same time")) + else + activations = (conv = activation, bn = identity) + end end - push!(layers, - Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...)) - push!(layers, - BatchNorm(Int(bnplanes), activations.bn)) - return rev ? reverse(layers) : layers + layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), + norm_layer(bnplanes, activations.bn)] + return prenorm ? reverse(layers) : layers +end + +function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes, + activation = identity; kwargs...) + inplanes, outplanes = ch + return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end """ - depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, use_bn = (true, true), + depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu; + prenorm = false, use_bn = (true, true), stride = 1, pad = 0, dilation = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: - - a `kernelsize` depthwise convolution from `inplanes => inplanes` + - a `kernel_size` depthwise convolution from `inplanes => inplanes` - a batch norm layer + `activation` (if `use_bn[1] == true`; otherwise `activation` is applied to the convolution output) - - a `kernelsize` convolution from `inplanes => outplanes` + - a `kernel_size` convolution from `inplanes => outplanes` - a batch norm layer + `activation` (if `use_bn[2] == true`; otherwise `activation` is applied to the convolution output) See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). # Arguments - - `kernelsize`: size of the convolution kernel (tuple) + - `kernel_size`: size of the convolution kernel (tuple) - `inplanes`: number of input feature maps - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - - `rev`: set to `true` to place the batch norm before the convolution + - `prenorm`: set to `true` to place the batch norm before the convolution - `use_bn`: a tuple of two booleans to specify whether to use batch normalization for the first and second convolution - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, use_bn = (true, true), +function depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu; + prenorm = false, use_bn = (true, true), stride = 1, kwargs...) 
- return vcat(conv_bn(kernelsize, inplanes, inplanes, activation; - rev, use_bn = use_bn[1], stride, groups = Int(inplanes), kwargs...), - conv_bn((1, 1), inplanes, outplanes, activation; rev, use_bn = use_bn[2])) + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; + prenorm, use_bn = use_bn[1], stride, groups = inplanes, + kwargs...), + conv_norm((1, 1), inplanes, outplanes, activation; prenorm, + use_bn = use_bn[2])) end """ @@ -97,8 +110,8 @@ Create a skip projection """ function skip_projection(inplanes, outplanes, downsample = false) return downsample ? - Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)) : - Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)) + Chain(conv_norm((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)) : + Chain(conv_norm((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)) end # array -> PaddedView(0, array, outplanes) for zero padding arrays @@ -151,15 +164,15 @@ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 conv1 = (inplanes == hidden_planes) ? identity : - Chain(conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false)) + Chain(conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false)) selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, norm_layer = BatchNorm) invres = Chain(conv1, - conv_bn(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad = pad, groups = hidden_planes)..., + conv_norm(kernel_size, hidden_planes, hidden_planes, activation; + bias = false, stride, pad = pad, groups = hidden_planes)..., selayer, - conv_bn((1, 1), hidden_planes, outplanes, identity; bias = false)...) + conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index e71634e22..bb83f042d 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -16,7 +16,6 @@ struct ChannelLayerNorm{D, T} diag::D ϵ::T end - @functor ChannelLayerNorm function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) From 6c005d313bb0b03720f623c08488d882e7195c8d Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 23 Jul 2022 08:01:39 +0530 Subject: [PATCH 50/64] Cleanup --- src/convnets/resnets/core.jl | 53 ++++++++++++++---------------------- src/layers/Layers.jl | 2 +- src/layers/mlp.jl | 23 ++++++++++++---- src/layers/selayers.jl | 3 +- 4 files changed, 40 insertions(+), 41 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 6fdc85664..ad60814d4 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,5 +1,3 @@ -## It is recommended to check out the user guide for more information. - """ basicblock(inplanes, planes; stride = 1, downsample = identity, reduction_factor = 1, dilation = 1, first_dilation = dilation, @@ -43,8 +41,7 @@ function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activati pad = 1, bias = false) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] - filter!(x -> x !== identity, layers) - return Chain(layers...) + return Chain(filter!(!=(identity), layers)...) 
end expansion_factor(::typeof(basicblock)) = 1 @@ -96,8 +93,7 @@ function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = bias = false) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] - filter!(x -> x !== identity, layers) - return Chain(layers...) + return Chain(filter!(!=(identity), layers)...) end expansion_factor(::typeof(bottleneck)) = 4 @@ -263,11 +259,12 @@ function configure_block(block_template, layers::Vector{Int}; expansion, function get_layers(idxs::NTuple{2, Int}) stage_idx, block_idx = idxs planes = 64 * 2^(stage_idx - 1) + # `get_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper. stride = get_stride(idxs) downsample_fns = downsample_templates[stage_idx] downsample_fn = (stride != 1 || inplanes != planes * expansion) ? - downsample_fns[1] : - downsample_fns[2] + downsample_fns[1] : downsample_fns[2] # DropBlock, DropPath both take in rates based on a linear scaling schedule schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) @@ -285,16 +282,9 @@ end # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function resnet_stages(block_fn, block_repeats::Vector{Int}, inplanes::Integer; - downsample_vec::Vector, connection = addact, - activation = relu, kwargs...) +function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer; + connection = addact, activation = relu, kwargs...) outplanes = 0 - # Configure block template - block_template = template_builder(block_fn; kwargs...) - downsample_templates = map(x -> template_builder.(x), downsample_vec) - get_layers = configure_block(block_template, block_repeats; inplanes, - downsample_templates, - expansion = expansion_factor(block_fn), kwargs...) # Construct each stage stages = [] for (stage_idx, (num_blocks)) in enumerate(block_repeats) @@ -313,24 +303,21 @@ end function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; inchannels::Integer = 3, nclasses::Integer = 1000, stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64, + pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0, kwargs...) - # Feature Blocks + # Configure downsample templates downsample_vec = _make_downsample_fns(downsample_opt, layers) - stage_blocks, num_features = resnet_stages(block_fn, layers, inplanes; downsample_vec, - kwargs...) - # Classifier head - pool_layer = get(kwargs, :pool_layer, AdaptiveMeanPool((1, 1))) - use_conv = get(kwargs, :use_conv, false) - # Pooling - if pool_layer === identity - @assert use_conv - "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" - end - flatten_in_pool = !use_conv && pool_layer !== identity - global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer - # Fully-connected layer - fc = create_fc(num_features, nclasses; use_conv) - classifier = Chain(global_pool, Dropout(get(kwargs, :dropout_rate, 0)), fc) + downsample_templates = map(x -> template_builder.(x), downsample_vec) + # Configure block templates + block_template = template_builder(block_fn; kwargs...) 
+ get_layers = configure_block(block_template, layers; inplanes, + downsample_templates, + expansion = expansion_factor(block_fn), kwargs...) + # Build stages of the ResNet + stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...) + # Build the classifier head + classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer, + use_conv) return Chain(Chain(stem, stage_blocks), classifier) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index bb4a1ef2b..9b5aa588b 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -20,7 +20,7 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mlp.jl") -export mlp_block, gated_mlp_block, create_fc +export mlp_block, gated_mlp_block, create_fc, create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index f72520451..4a623c977 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -47,17 +47,30 @@ end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) """ - create_fc(inplanes, nclasses; use_conv = false) + create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = 0.0, use_conv = false) -Creates a classifier head to be used for models. Uses `SelectAdaptivePool` for the pooling layer. +Creates a classifier head to be used for models. # Arguments - `inplanes`: number of input feature maps - `nclasses`: number of output classes + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. """ -function create_fc(inplanes, nclasses; use_conv = false) - return use_conv ? Conv((1, 1), inplanes => nclasses; bias = true) : - Dense(inplanes => nclasses; bias = true) +function create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = 0.0, use_conv = false) + # Pooling + if pool_layer === identity + @assert use_conv + "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" + end + flatten_in_pool = !use_conv && pool_layer !== identity + global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer + # Fully-connected layer + fc = use_conv ? 
Conv((1, 1), inplanes => nclasses) : Dense(inplanes => nclasses) + return Chain(global_pool, Dropout(dropout_rate), fc) end diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 8572adb9f..db0f3715d 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -26,8 +26,7 @@ function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, Conv((1, 1), rd_planes => inplanes), norm_layer(inplanes), gate_activation] - filter!(x -> x !== identity, layers) - return SkipConnection(Chain(layers...), .*) + return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) end """ From bd443f1f5bfdfcbd61211e02c56f6eb1bdeaa756 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 23 Jul 2022 11:16:37 +0530 Subject: [PATCH 51/64] Reorganisation And some more formatting --- .github/workflows/CI.yml | 3 +- src/Metalhead.jl | 34 +- src/convnets/alexnet.jl | 4 +- src/convnets/convmixer.jl | 26 +- src/convnets/convnext.jl | 25 +- src/convnets/densenet.jl | 14 +- src/convnets/efficientnet.jl | 33 +- src/convnets/inception.jl | 605 -------------------- src/convnets/{ => inception}/googlenet.jl | 6 +- src/convnets/inception/inceptionresnetv2.jl | 133 +++++ src/convnets/inception/inceptionv3.jl | 196 +++++++ src/convnets/inception/inceptionv4.jl | 158 +++++ src/convnets/inception/xception.jl | 106 ++++ src/convnets/mobilenet.jl | 339 ----------- src/convnets/mobilenet/mobilenetv1.jl | 102 ++++ src/convnets/mobilenet/mobilenetv2.jl | 97 ++++ src/convnets/mobilenet/mobilenetv3.jl | 129 +++++ src/convnets/resnets/core.jl | 16 +- src/convnets/resnets/resnet.jl | 15 +- src/convnets/resnets/resnext.jl | 5 +- src/convnets/resnets/seresnet.jl | 12 +- src/convnets/squeezenet.jl | 43 +- src/convnets/vgg.jl | 20 +- src/mixers/core.jl | 43 ++ src/mixers/gmlp.jl | 110 ++++ src/mixers/mlpmixer.jl | 69 +++ src/mixers/resmlp.jl | 72 +++ src/other/mlpmixer.jl | 302 ---------- src/utilities.jl | 8 +- src/vit-based/vit.jl | 5 +- test/convnets.jl | 4 +- test/{other.jl => mixers.jl} | 6 +- test/runtests.jl | 8 +- test/{vit-based.jl => vits.jl} | 2 +- 34 files changed, 1345 insertions(+), 1405 deletions(-) delete mode 100644 src/convnets/inception.jl rename src/convnets/{ => inception}/googlenet.jl (99%) create mode 100644 src/convnets/inception/inceptionresnetv2.jl create mode 100644 src/convnets/inception/inceptionv3.jl create mode 100644 src/convnets/inception/inceptionv4.jl create mode 100644 src/convnets/inception/xception.jl delete mode 100644 src/convnets/mobilenet.jl create mode 100644 src/convnets/mobilenet/mobilenetv1.jl create mode 100644 src/convnets/mobilenet/mobilenetv2.jl create mode 100644 src/convnets/mobilenet/mobilenetv3.jl create mode 100644 src/mixers/core.jl create mode 100644 src/mixers/gmlp.jl create mode 100644 src/mixers/mlpmixer.jl create mode 100644 src/mixers/resmlp.jl delete mode 100644 src/other/mlpmixer.jl rename test/{other.jl => mixers.jl} (76%) rename test/{vit-based.jl => vits.jl} (65%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e148bdff2..32882857b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,9 +34,8 @@ jobs: - '["InceptionResNetv2", "Xception"]' - '"DenseNet"' - '["ConvNeXt", "ConvMixer"]' - # - '"ConvMixer"' - '"ViT"' - - '"Other"' + - '"Mixers"' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 67731825e..3c8469dd2 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -21,26 +21,38 @@ 
using .Layers # CNN models include("convnets/alexnet.jl") include("convnets/vgg.jl") -include("convnets/inception.jl") -include("convnets/googlenet.jl") -include("convnets/densenet.jl") -include("convnets/squeezenet.jl") -include("convnets/mobilenet.jl") -include("convnets/efficientnet.jl") -include("convnets/convnext.jl") -include("convnets/convmixer.jl") ## ResNets include("convnets/resnets/core.jl") include("convnets/resnets/resnet.jl") include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") +## Inceptions +include("convnets/inception/googlenet.jl") +include("convnets/inception/inceptionv3.jl") +include("convnets/inception/inceptionv4.jl") +include("convnets/inception/inceptionresnetv2.jl") +include("convnets/inception/xception.jl") +## MobileNets +include("convnets/mobilenet/mobilenetv1.jl") +include("convnets/mobilenet/mobilenetv2.jl") +include("convnets/mobilenet/mobilenetv3.jl") +## Others +include("convnets/densenet.jl") +include("convnets/squeezenet.jl") +include("convnets/efficientnet.jl") +include("convnets/convnext.jl") +include("convnets/convmixer.jl") -# Other models -include("other/mlpmixer.jl") +# Mixers +include("mixers/core.jl") +include("mixers/mlpmixer.jl") +include("mixers/resmlp.jl") +include("mixers/gmlp.jl") -# ViT-based models +# ViTs include("vit-based/vit.jl") +# Load pretrained weights include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 87f2c288e..8ff65ffef 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -24,7 +24,6 @@ function alexnet(; nclasses = 1000) Dropout(0.5), Dense(4096, 4096, relu), Dense(4096, nclasses))) - return layers end @@ -46,6 +45,7 @@ See also [`alexnet`](#). struct AlexNet layers::Any end +@functor AlexNet function AlexNet(; pretrain = false, nclasses = 1000) layers = alexnet(; nclasses = nclasses) @@ -55,8 +55,6 @@ function AlexNet(; pretrain = false, nclasses = 1000) return AlexNet(layers) end -@functor AlexNet - (m::AlexNet)(x) = m.layers(x) backbone(m::AlexNet) = m.layers[1] diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index f3566e278..6547ba4fb 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -28,13 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), return Chain(Chain(stem..., Chain(blocks)), head) end -convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9), - :patch_size => (7, 7)), - :small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7), - :patch_size => (7, 7)), - :large => Dict(:planes => 1024, :depth => 20, +convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9), - :patch_size => (7, 7))) + :patch_size => (7, 7)), + :small => Dict(:planes => 768, :depth => 32, + :kernel_size => (7, 7), + :patch_size => (7, 7)), + :large => Dict(:planes => 1024, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7))) """ ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) @@ -52,19 +54,19 @@ Creates a ConvMixer model. 
struct ConvMixer layers::Any end +@functor ConvMixer function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) - planes = convmixer_config[mode][:planes] - depth = convmixer_config[mode][:depth] - kernel_size = convmixer_config[mode][:kernel_size] - patch_size = convmixer_config[mode][:patch_size] + _checkconfig(mode, keys(convmixer_configs)) + planes = convmixer_configs[mode][:planes] + depth = convmixer_configs[mode][:depth] + kernel_size = convmixer_configs[mode][:kernel_size] + patch_size = convmixer_configs[mode][:patch_size] layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses) return ConvMixer(layers) end -@functor ConvMixer - (m::ConvMixer)(x) = m.layers(x) backbone(m::ConvMixer) = m.layers[1] diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 383e5d128..052192fec 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -66,20 +66,16 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 end # Configurations for ConvNeXt models -convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3], - :planes => [96, 192, 384, 768]), - :small => Dict(:depths => [3, 3, 27, 3], - :planes => [96, 192, 384, 768]), - :base => Dict(:depths => [3, 3, 27, 3], - :planes => [128, 256, 512, 1024]), - :large => Dict(:depths => [3, 3, 27, 3], - :planes => [192, 384, 768, 1536]), - :xlarge => Dict(:depths => [3, 3, 27, 3], - :planes => [256, 512, 1024, 2048])) +convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), + :small => ([3, 3, 27, 3], [96, 192, 384, 768]), + :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), + :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), + :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) struct ConvNeXt layers::Any end +@functor ConvNeXt """ ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000) @@ -98,17 +94,12 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert mode in keys(convnext_configs) - "`size` must be one of $(collect(keys(convnext_configs)))" - depths = convnext_configs[mode][:depths] - planes = convnext_configs[mode][:planes] - layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses) + _checkconfig(mode, keys(convnext_configs)) + layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses) return ConvNeXt(layers) end (m::ConvNeXt)(x) = m.layers(x) -@functor ConvNeXt - backbone(m::ConvNeXt) = m.layers[1] classifier(m::ConvNeXt) = m.layers[2] diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 878e3fc8d..0c5bd6ad6 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -140,14 +140,14 @@ end backbone(m::DenseNet) = m.layers[1] classifier(m::DenseNet) = m.layers[2] -const densenet_config = Dict(121 => (6, 12, 24, 16), - 161 => (6, 12, 36, 24), - 169 => (6, 12, 32, 32), - 201 => (6, 12, 48, 32)) +const densenet_configs = Dict(121 => (6, 12, 24, 16), + 161 => (6, 12, 36, 24), + 169 => (6, 12, 32, 32), + 201 => (6, 12, 48, 32)) """ DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - DenseNet(transition_config::NTuple{N,Integer}) + DenseNet(transition_configs::NTuple{N,Integer}) Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201) ([reference](https://arxiv.org/abs/1608.06993)). 
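A quick usage sketch of the configuration-based constructor described above, using only names that appear in this diff:

    using Metalhead

    model = DenseNet(169; nclasses = 100)              # any of 121, 161, 169, 201
    # the tuple form builds the same network from raw block counts:
    model = DenseNet((6, 12, 32, 32); nclasses = 100)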
@@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. See also [`Metalhead.densenet`](#). """ function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))." - model = DenseNet(densenet_config[config]; nclasses = nclasses) + _checkconfig(config, keys(densenet_configs)) + model = DenseNet(densenet_configs[config]; nclasses = nclasses) if pretrain loadpretrain!(model, string("DenseNet", config)) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index aeacc8092..122fd512a 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -1,5 +1,5 @@ """ - efficientnet(scalings, block_config; + efficientnet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). @@ -8,7 +8,7 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `scalings`: global width and depth scaling (given as a tuple) - - `block_config`: configuration for each inverted residual block, + - `block_configs`: configuration for each inverted residual block, given as a vector of tuples with elements: + `n`: number of block repetitions (will be scaled by global depth scaling) @@ -22,22 +22,19 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `max_width`: maximum number of output channels before the fully connected classification blocks """ -function efficientnet(scalings, block_config; +function efficientnet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - out_channels = _round_channels(scalew(32), 8) stem = conv_norm((3, 3), inchannels, out_channels, swish; bias = false, stride = 2, pad = SamePad()) - blocks = [] - for (n, k, s, e, i, o) in block_config + for (n, k, s, e, i, o) in block_configs in_channels = _round_channels(scalew(i), 8) out_channels = _round_channels(scalew(o), 8) repeats = scaled(n) - push!(blocks, invertedresidual(k, in_channels, in_channels * e, out_channels, swish; stride = s, reduction = 4)) @@ -48,13 +45,10 @@ function efficientnet(scalings, block_config; end end blocks = Chain(blocks...) - head_out_channels = _round_channels(max_width, 8) head = conv_norm((1, 1), out_channels, head_out_channels, swish; bias = false, pad = SamePad()) - top = Dense(head_out_channels, nclasses) - return Chain(Chain([stem..., blocks, head...]), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top)) end @@ -92,9 +86,10 @@ const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)), struct EfficientNet layers::Any end +@functor EfficientNet """ - EfficientNet(scalings, block_config; + EfficientNet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). @@ -104,7 +99,7 @@ See also [`efficientnet`](#). - `scalings`: global width and depth scaling (given as a tuple) - - `block_config`: configuration for each inverted residual block, + - `block_configs`: configuration for each inverted residual block, given as a vector of tuples with elements: + `n`: number of block repetitions (will be scaled by global depth scaling) @@ -118,17 +113,12 @@ See also [`efficientnet`](#). 
- `max_width`: maximum number of output channels before the fully connected classification blocks """ -function EfficientNet(scalings, block_config; +function EfficientNet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) - layers = efficientnet(scalings, block_config; - inchannels = inchannels, - nclasses = nclasses, - max_width = max_width) + layers = efficientnet(scalings, block_configs; inchannels, nclasses, max_width) return EfficientNet(layers) end -@functor EfficientNet - (m::EfficientNet)(x) = m.layers(x) backbone(m::EfficientNet) = m.layers[1] @@ -147,11 +137,8 @@ See also [`efficientnet`](#). - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) - @assert name in keys(efficientnet_global_configs) - "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" - + _checkconfig(name, keys(efficientnet_global_configs)) model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) pretrain && loadpretrain!(model, string("efficientnet-", name)) - return model end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl deleted file mode 100644 index 4fd43f26c..000000000 --- a/src/convnets/inception.jl +++ /dev/null @@ -1,605 +0,0 @@ -## Inceptionv3 - -""" - inceptionv3_a(inplanes, pool_proj) - -Create an Inception-v3 style-A module -(ref: Fig. 5 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps - - `pool_proj`: the number of output feature maps for the pooling projection -""" -function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, pool_proj)...) - return Parallel(cat_channels, - branch1x1, branch5x5, branch3x3, branch_pool) -end - -""" - inceptionv3_b(inplanes) - -Create an Inception-v3 style-B module -(ref: Fig. 10 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; stride = 2)...) - branch_pool = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, - branch3x3_1, branch3x3_2, branch_pool) -end - -""" - inceptionv3_c(inplanes, inner_planes, n = 7) - -Create an Inception-v3 style-C module -(ref: Fig. 6 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps - - `inner_planes`: the number of output feature maps within each branch - - `n`: the "grid size" (kernel size) for the convolution layers -""" -function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) 
- branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) - return Parallel(cat_channels, - branch1x1, branch7x7_1, branch7x7_2, branch_pool) -end - -""" - inceptionv3_d(inplanes) - -Create an Inception-v3 style-D module -(ref: [pytorch](https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/inception.py#L322)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((1, 7), 192, 192; pad = (0, 3))..., - conv_norm((7, 1), 192, 192; pad = (3, 0))..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch_pool = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, - branch3x3, branch7x7x3, branch_pool) -end - -""" - inceptionv3_e(inplanes) - -Create an Inception-v3 style-E module -(ref: Fig. 7 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., - conv_norm((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) - return Parallel(cat_channels, - branch1x1, - Chain(branch3x3_1, - Parallel(cat_channels, - branch3x3_1a, branch3x3_1b)), - Chain(branch3x3_2, - Parallel(cat_channels, - branch3x3_2a, branch3x3_2b)), - branch_pool) -end - -""" - inceptionv3(; nclasses = 1000) - -Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `nclasses`: the number of output classes -""" -function inceptionv3(; nclasses = 1000) - layer = Chain(Chain(conv_norm((3, 3), 3, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - MaxPool((3, 3); stride = 2), - conv_norm((1, 1), 64, 80)..., - conv_norm((3, 3), 80, 192)..., - MaxPool((3, 3); stride = 2), - inceptionv3_a(192, 32), - inceptionv3_a(256, 64), - inceptionv3_a(288, 64), - inceptionv3_b(288), - inceptionv3_c(768, 128), - inceptionv3_c(768, 160), - inceptionv3_c(768, 160), - inceptionv3_c(768, 192), - inceptionv3_d(768), - inceptionv3_e(1280), - inceptionv3_e(2048)), - Chain(AdaptiveMeanPool((1, 1)), - Dropout(0.2), - MLUtils.flatten, - Dense(2048, nclasses))) - return layer -end - -""" - Inceptionv3(; pretrain = false, nclasses = 1000) - -Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). -See also [`inceptionv3`](#). - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -!!! 
warning - - `Inceptionv3` does not currently support pretrained weights. -""" -struct Inceptionv3 - layers::Any -end - -function Inceptionv3(; pretrain = false, nclasses = 1000) - layers = inceptionv3(; nclasses = nclasses) - if pretrain - loadpretrain!(layers, "Inceptionv3") - end - return Inceptionv3(layers) -end - -@functor Inceptionv3 - -(m::Inceptionv3)(x) = m.layers(x) - -backbone(m::Inceptionv3) = m.layers[1] -classifier(m::Inceptionv3) = m.layers[2] - -## Inceptionv4 - -function mixed_3a() - return Parallel(cat_channels, - MaxPool((3, 3); stride = 2), - Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) -end - -function mixed_4a() - return Parallel(cat_channels, - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((3, 3), 64, 96)...), - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((1, 7), 64, 64; pad = (0, 3))..., - conv_norm((7, 1), 64, 64; pad = (3, 0))..., - conv_norm((3, 3), 64, 96)...)) -end - -function mixed_5a() - return Parallel(cat_channels, - Chain(conv_norm((3, 3), 192, 192; stride = 2)...), - MaxPool((3, 3); stride = 2)) -end - -function inceptionv4_a() - branch1 = Chain(conv_norm((1, 1), 384, 96)...) - branch2 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_a() - branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 384, 192)..., - conv_norm((3, 3), 192, 224; pad = 1)..., - conv_norm((3, 3), 224, 256; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_b() - branch1 = Chain(conv_norm((1, 1), 1024, 384)...) - branch2 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((1, 7), 192, 224; pad = (0, 3))..., - conv_norm((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((7, 1), 192, 192; pad = (0, 3))..., - conv_norm((1, 7), 192, 224; pad = (3, 0))..., - conv_norm((7, 1), 224, 224; pad = (0, 3))..., - conv_norm((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_b() - branch1 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1024, 256)..., - conv_norm((1, 7), 256, 256; pad = (0, 3))..., - conv_norm((7, 1), 256, 320; pad = (3, 0))..., - conv_norm((3, 3), 320, 320; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_c() - branch1 = Chain(conv_norm((1, 1), 1536, 256)...) - branch2 = Chain(conv_norm((1, 1), 1536, 384)..., - Parallel(cat_channels, - Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_norm((1, 1), 1536, 384)..., - conv_norm((3, 1), 384, 448; pad = (1, 0))..., - conv_norm((1, 3), 448, 512; pad = (0, 1))..., - Parallel(cat_channels, - Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) 
- return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -""" - inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Create an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - mixed_3a(), - mixed_4a(), - mixed_5a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - reduction_a(), # mixed_6a - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - reduction_b(), # mixed_7a - inceptionv4_c(), - inceptionv4_c(), - inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), - Dense(1536, nclasses)) - return Chain(body, head) -end - -""" - Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `Inceptionv4` does not currently support pretrained weights. -""" -struct Inceptionv4 - layers::Any -end - -function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, - nclasses = 1000) - layers = inceptionv4(; inchannels, dropout_rate, nclasses) - if pretrain - loadpretrain!(layers, "Inceptionv4") - end - return Inceptionv4(layers) -end - -@functor Inceptionv4 - -(m::Inceptionv4)(x) = m.layers(x) - -backbone(m::Inceptionv4) = m.layers[1] -classifier(m::Inceptionv4) = m.layers[2] - -## Inception-ResNetv2 - -function mixed_5b() - branch1 = Chain(conv_norm((1, 1), 192, 96)...) - branch2 = Chain(conv_norm((1, 1), 192, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_norm((1, 1), 192, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), 192, 64)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function block35(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 320, 32)...) - branch2 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 48; pad = 1)..., - conv_norm((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_norm((1, 1), 128, 320)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), - branch4, inputscale(scale; activation = relu)), +) -end - -function mixed_6a() - branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 320, 256)..., - conv_norm((3, 3), 256, 256; pad = 1)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function block17(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 1088, 192)...) 
- branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((1, 7), 128, 160; pad = (0, 3))..., - conv_norm((7, 1), 160, 192; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 384, 1088)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), - branch3, inputscale(scale; activation = relu)), +) -end - -function mixed_7a() - branch1 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; pad = 1)..., - conv_norm((3, 3), 288, 320; stride = 2)...) - branch4 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_norm((1, 1), 2080, 192)...) - branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((1, 3), 192, 224; pad = (0, 1))..., - conv_norm((3, 1), 224, 256; pad = (1, 0))...) - branch3 = Chain(conv_norm((1, 1), 448, 2080)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), - branch3, inputscale(scale; activation)), +) -end - -""" - inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an InceptionResNetv2 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - MaxPool((3, 3); stride = 2), - conv_norm((3, 3), 64, 80)..., - conv_norm((3, 3), 80, 192)..., - MaxPool((3, 3); stride = 2), - mixed_5b(), - [block35(0.17f0) for _ in 1:10]..., - mixed_6a(), - [block17(0.10f0) for _ in 1:20]..., - mixed_7a(), - [block8(0.20f0) for _ in 1:9]..., - block8(; activation = relu), - conv_norm((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), - Dense(1536, nclasses)) - return Chain(body, head) -end - -""" - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an InceptionResNetv2 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `InceptionResNetv2` does not currently support pretrained weights. -""" -struct InceptionResNetv2 - layers::Any -end - -function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, - nclasses = 1000) - layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) - if pretrain - loadpretrain!(layers, "InceptionResNetv2") - end - return InceptionResNetv2(layers) -end - -@functor InceptionResNetv2 - -(m::InceptionResNetv2)(x) = m.layers(x) - -backbone(m::InceptionResNetv2) = m.layers[1] -classifier(m::InceptionResNetv2) = m.layers[2] - -## Xception - -""" - xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, - grow_at_start = true) - -Create an Xception block. 
-([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `inchannels`: The number of channels in the input. - - `outchannels`: number of output channels. - - `nrepeats`: number of repeats of depthwise separable convolution layers. - - `stride`: stride by which to downsample the input. - - `start_with_relu`: if true, start the block with a ReLU activation. - - `grow_at_start`: if true, increase the number of channels at the first convolution. -""" -function xception_block(inchannels, outchannels, nrepeats; stride = 1, - start_with_relu = true, - grow_at_start = true) - if outchannels != inchannels || stride != 1 - skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) - else - skip = [identity] - end - layers = [] - for i in 1:nrepeats - if grow_at_start - inc = i == 1 ? inchannels : outchannels - outc = outchannels - else - inc = inchannels - outc = i == nrepeats ? outchannels : inchannels - end - push!(layers, relu) - append!(layers, - depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, - use_bn = (false, false))) - push!(layers, BatchNorm(outc)) - end - layers = start_with_relu ? layers : layers[2:end] - push!(layers, MaxPool((3, 3); stride = stride, pad = 1)) - return Parallel(+, Chain(skip...), Chain(layers...)) -end - -""" - xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an Xception model. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_norm((3, 3), 32, 64; bias = false)..., - xception_block(64, 128, 2; stride = 2, start_with_relu = false), - xception_block(128, 256, 2; stride = 2), - xception_block(256, 728, 2; stride = 2), - [xception_block(728, 728, 3) for _ in 1:8]..., - xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), - Dense(2048, nclasses)) - return Chain(body, head) -end - -struct Xception - layers::Any -end - -""" - Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an Xception model. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `Xception` does not currently support pretrained weights. 
-""" -function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - layers = xception(; inchannels, dropout_rate, nclasses) - if pretrain - loadpretrain!(layers, "xception") - end - return Xception(layers) -end - -@functor Xception - -(m::Xception)(x) = m.layers(x) - -backbone(m::Xception) = m.layers[1] -classifier(m::Xception) = m.layers[2] diff --git a/src/convnets/googlenet.jl b/src/convnets/inception/googlenet.jl similarity index 99% rename from src/convnets/googlenet.jl rename to src/convnets/inception/googlenet.jl index 946d0d7f7..8a88ca943 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/inception/googlenet.jl @@ -16,15 +16,12 @@ Create an inception module for use in GoogLeNet """ function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj) branch1 = Chain(Conv((1, 1), inplanes => out_1x1)) - branch2 = Chain(Conv((1, 1), inplanes => red_3x3), Conv((3, 3), red_3x3 => out_3x3; pad = 1)) - branch3 = Chain(Conv((1, 1), inplanes => red_5x5), Conv((5, 5), red_5x5 => out_5x5; pad = 2)) branch4 = Chain(MaxPool((3, 3); stride = 1, pad = 1), Conv((1, 1), inplanes => pool_proj)) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) end @@ -83,6 +80,7 @@ See also [`googlenet`](#). struct GoogLeNet layers::Any end +@functor GoogLeNet function GoogLeNet(; pretrain = false, nclasses = 1000) layers = googlenet(; nclasses = nclasses) @@ -92,8 +90,6 @@ function GoogLeNet(; pretrain = false, nclasses = 1000) return GoogLeNet(layers) end -@functor GoogLeNet - (m::GoogLeNet)(x) = m.layers(x) backbone(m::GoogLeNet) = m.layers[1] diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl new file mode 100644 index 000000000..4b4b78706 --- /dev/null +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -0,0 +1,133 @@ +function mixed_5b() + branch1 = Chain(conv_norm((1, 1), 192, 96)...) + branch2 = Chain(conv_norm((1, 1), 192, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(conv_norm((1, 1), 192, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), 192, 64)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function block35(scale = 1.0f0) + branch1 = Chain(conv_norm((1, 1), 320, 32)...) + branch2 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 48; pad = 1)..., + conv_norm((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(conv_norm((1, 1), 128, 320)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), + branch4, inputscale(scale; activation = relu)), +) +end + +function mixed_6a() + branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 320, 256)..., + conv_norm((3, 3), 256, 256; pad = 1)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function block17(scale = 1.0f0) + branch1 = Chain(conv_norm((1, 1), 1088, 192)...) + branch2 = Chain(conv_norm((1, 1), 1088, 128)..., + conv_norm((1, 7), 128, 160; pad = (0, 3))..., + conv_norm((7, 1), 160, 192; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 384, 1088)...) 
+ return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), + branch3, inputscale(scale; activation = relu)), +) +end + +function mixed_7a() + branch1 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; pad = 1)..., + conv_norm((3, 3), 288, 320; stride = 2)...) + branch4 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function block8(scale = 1.0f0; activation = identity) + branch1 = Chain(conv_norm((1, 1), 2080, 192)...) + branch2 = Chain(conv_norm((1, 1), 2080, 192)..., + conv_norm((1, 3), 192, 224; pad = (0, 1))..., + conv_norm((3, 1), 224, 256; pad = (1, 0))...) + branch3 = Chain(conv_norm((1, 1), 448, 2080)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), + branch3, inputscale(scale; activation)), +) +end + +""" + inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an InceptionResNetv2 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + MaxPool((3, 3); stride = 2), + conv_norm((3, 3), 64, 80)..., + conv_norm((3, 3), 80, 192)..., + MaxPool((3, 3); stride = 2), + mixed_5b(), + [block35(0.17f0) for _ in 1:10]..., + mixed_6a(), + [block17(0.10f0) for _ in 1:20]..., + mixed_7a(), + [block8(0.20f0) for _ in 1:9]..., + block8(; activation = relu), + conv_norm((1, 1), 2080, 1536)...) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(1536, nclasses)) + return Chain(body, head) +end + +""" + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an InceptionResNetv2 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `InceptionResNetv2` does not currently support pretrained weights. +""" +struct InceptionResNetv2 + layers::Any +end +@functor InceptionResNetv2 + +function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "InceptionResNetv2") + end + return InceptionResNetv2(layers) +end + +(m::InceptionResNetv2)(x) = m.layers(x) + +backbone(m::InceptionResNetv2) = m.layers[1] +classifier(m::InceptionResNetv2) = m.layers[2] diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl new file mode 100644 index 000000000..68b283838 --- /dev/null +++ b/src/convnets/inception/inceptionv3.jl @@ -0,0 +1,196 @@ +## Inceptionv3 + +""" + inceptionv3_a(inplanes, pool_proj) + +Create an Inception-v3 style-A module +(ref: Fig. 5 in [paper](https://arxiv.org/abs/1512.00567v3)). 
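To make the branch structure concrete before the argument list: the module runs several convolutional branches in parallel over the same input and concatenates their outputs along the channel dimension. A stripped-down sketch of that pattern in plain Flux (the `cat_channels` definition and the channel counts here are illustrative assumptions, not Metalhead internals):

    using Flux

    cat_channels(xs...) = cat(xs...; dims = 3)      # assumed equivalent of Metalhead's helper
    block = Parallel(cat_channels,
                     Conv((1, 1), 192 => 64, relu),
                     Chain(Conv((1, 1), 192 => 48, relu),
                           Conv((5, 5), 48 => 64, relu; pad = 2)),
                     Chain(MeanPool((3, 3); pad = 1, stride = 1),
                           Conv((1, 1), 192 => 32, relu)))
    size(block(rand(Float32, 35, 35, 192, 1)))      # (35, 35, 160, 1): 64 + 64 + 32 channels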
+ +# Arguments + + - `inplanes`: number of input feature maps + - `pool_proj`: the number of output feature maps for the pooling projection +""" +function inceptionv3_a(inplanes, pool_proj) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) + branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, pool_proj)...) + return Parallel(cat_channels, + branch1x1, branch5x5, branch3x3, branch_pool) +end + +""" + inceptionv3_b(inplanes) + +Create an Inception-v3 style-B module +(ref: Fig. 10 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_b(inplanes) + branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; stride = 2)...) + branch_pool = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, + branch3x3_1, branch3x3_2, branch_pool) +end + +""" + inceptionv3_c(inplanes, inner_planes, n = 7) + +Create an Inception-v3 style-C module +(ref: Fig. 6 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps + - `inner_planes`: the number of output feature maps within each branch + - `n`: the "grid size" (kernel size) for the convolution layers +""" +function inceptionv3_c(inplanes, inner_planes, n = 7) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) + branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) + branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, 192)...) + return Parallel(cat_channels, + branch1x1, branch7x7_1, branch7x7_2, branch_pool) +end + +""" + inceptionv3_d(inplanes) + +Create an Inception-v3 style-D module +(ref: [pytorch](https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/inception.py#L322)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_d(inplanes) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((1, 7), 192, 192; pad = (0, 3))..., + conv_norm((7, 1), 192, 192; pad = (3, 0))..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch_pool = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, + branch3x3, branch7x7x3, branch_pool) +end + +""" + inceptionv3_e(inplanes) + +Create an Inception-v3 style-E module +(ref: Fig. 7 in [paper](https://arxiv.org/abs/1512.00567v3)). 
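One more aside before the argument list: the style-C and style-E modules replace large square kernels with pairs of asymmetric 1×n and n×1 convolutions, which cover the same receptive field with far fewer parameters. A rough comparison in plain Flux (sizes are illustrative):

    using Flux

    full     = Conv((7, 7), 128 => 128; pad = 3)
    factored = Chain(Conv((1, 7), 128 => 128; pad = (0, 3)),
                     Conv((7, 1), 128 => 128; pad = (3, 0)))
    sum(length, Flux.params(full))       # 49 * 128^2 + 128  ≈ 803k parameters
    sum(length, Flux.params(factored))   # 2 * (7 * 128^2 + 128) ≈ 230k parameters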
+ +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_e(inplanes) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) + branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) + branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., + conv_norm((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, 192)...) + return Parallel(cat_channels, + branch1x1, + Chain(branch3x3_1, + Parallel(cat_channels, + branch3x3_1a, branch3x3_1b)), + Chain(branch3x3_2, + Parallel(cat_channels, + branch3x3_2a, branch3x3_2b)), + branch_pool) +end + +""" + inceptionv3(; nclasses = 1000) + +Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `nclasses`: the number of output classes +""" +function inceptionv3(; nclasses = 1000) + layer = Chain(Chain(conv_norm((3, 3), 3, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + MaxPool((3, 3); stride = 2), + conv_norm((1, 1), 64, 80)..., + conv_norm((3, 3), 80, 192)..., + MaxPool((3, 3); stride = 2), + inceptionv3_a(192, 32), + inceptionv3_a(256, 64), + inceptionv3_a(288, 64), + inceptionv3_b(288), + inceptionv3_c(768, 128), + inceptionv3_c(768, 160), + inceptionv3_c(768, 160), + inceptionv3_c(768, 192), + inceptionv3_d(768), + inceptionv3_e(1280), + inceptionv3_e(2048)), + Chain(AdaptiveMeanPool((1, 1)), + Dropout(0.2), + MLUtils.flatten, + Dense(2048, nclasses))) + return layer +end + +""" + Inceptionv3(; pretrain = false, nclasses = 1000) + +Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). +See also [`inceptionv3`](#). + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +!!! warning + + `Inceptionv3` does not currently support pretrained weights. +""" +struct Inceptionv3 + layers::Any +end + +function Inceptionv3(; pretrain = false, nclasses = 1000) + layers = inceptionv3(; nclasses = nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv3") + end + return Inceptionv3(layers) +end + +@functor Inceptionv3 + +(m::Inceptionv3)(x) = m.layers(x) + +backbone(m::Inceptionv3) = m.layers[1] +classifier(m::Inceptionv3) = m.layers[2] diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl new file mode 100644 index 000000000..bb03646ec --- /dev/null +++ b/src/convnets/inception/inceptionv4.jl @@ -0,0 +1,158 @@ +function mixed_3a() + return Parallel(cat_channels, + MaxPool((3, 3); stride = 2), + Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) +end + +function mixed_4a() + return Parallel(cat_channels, + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((3, 3), 64, 96)...), + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((1, 7), 64, 64; pad = (0, 3))..., + conv_norm((7, 1), 64, 64; pad = (3, 0))..., + conv_norm((3, 3), 64, 96)...)) +end + +function mixed_5a() + return Parallel(cat_channels, + Chain(conv_norm((3, 3), 192, 192; stride = 2)...), + MaxPool((3, 3); stride = 2)) +end + +function inceptionv4_a() + branch1 = Chain(conv_norm((1, 1), 384, 96)...) + branch2 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)...) 
+ branch3 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_a() + branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 384, 192)..., + conv_norm((3, 3), 192, 224; pad = 1)..., + conv_norm((3, 3), 224, 256; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_b() + branch1 = Chain(conv_norm((1, 1), 1024, 384)...) + branch2 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((1, 7), 192, 224; pad = (0, 3))..., + conv_norm((7, 1), 224, 256; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((7, 1), 192, 192; pad = (0, 3))..., + conv_norm((1, 7), 192, 224; pad = (3, 0))..., + conv_norm((7, 1), 224, 224; pad = (0, 3))..., + conv_norm((1, 7), 224, 256; pad = (3, 0))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_b() + branch1 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1024, 256)..., + conv_norm((1, 7), 256, 256; pad = (0, 3))..., + conv_norm((7, 1), 256, 320; pad = (3, 0))..., + conv_norm((3, 3), 320, 320; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_c() + branch1 = Chain(conv_norm((1, 1), 1536, 256)...) + branch2 = Chain(conv_norm((1, 1), 1536, 384)..., + Parallel(cat_channels, + Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) + branch3 = Chain(conv_norm((1, 1), 1536, 384)..., + conv_norm((3, 1), 384, 448; pad = (1, 0))..., + conv_norm((1, 3), 448, 512; pad = (0, 1))..., + Parallel(cat_channels, + Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +""" + inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Create an Inceptionv4 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + mixed_3a(), + mixed_4a(), + mixed_5a(), + inceptionv4_a(), + inceptionv4_a(), + inceptionv4_a(), + inceptionv4_a(), + reduction_a(), # mixed_6a + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + reduction_b(), # mixed_7a + inceptionv4_c(), + inceptionv4_c(), + inceptionv4_c()) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(1536, nclasses)) + return Chain(body, head) +end + +""" + Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Inceptionv4 model. 
+([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `Inceptionv4` does not currently support pretrained weights. +""" +struct Inceptionv4 + layers::Any +end +@functor Inceptionv4 + +function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionv4(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv4") + end + return Inceptionv4(layers) +end + +(m::Inceptionv4)(x) = m.layers(x) + +backbone(m::Inceptionv4) = m.layers[1] +classifier(m::Inceptionv4) = m.layers[2] diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl new file mode 100644 index 000000000..6f4928385 --- /dev/null +++ b/src/convnets/inception/xception.jl @@ -0,0 +1,106 @@ +""" + xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, + grow_at_start = true) + +Create an Xception block. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `inchannels`: The number of channels in the input. + - `outchannels`: number of output channels. + - `nrepeats`: number of repeats of depthwise separable convolution layers. + - `stride`: stride by which to downsample the input. + - `start_with_relu`: if true, start the block with a ReLU activation. + - `grow_at_start`: if true, increase the number of channels at the first convolution. +""" +function xception_block(inchannels, outchannels, nrepeats; stride = 1, + start_with_relu = true, + grow_at_start = true) + if outchannels != inchannels || stride != 1 + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, + bias = false) + else + skip = [identity] + end + layers = [] + for i in 1:nrepeats + if grow_at_start + inc = i == 1 ? inchannels : outchannels + outc = outchannels + else + inc = inchannels + outc = i == nrepeats ? outchannels : inchannels + end + push!(layers, relu) + append!(layers, + depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, + use_bn = (false, false))) + push!(layers, BatchNorm(outc)) + end + layers = start_with_relu ? layers : layers[2:end] + push!(layers, MaxPool((3, 3); stride = stride, pad = 1)) + return Parallel(+, Chain(skip...), Chain(layers...)) +end + +""" + xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Xception model. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., + conv_norm((3, 3), 32, 64; bias = false)..., + xception_block(64, 128, 2; stride = 2, start_with_relu = false), + xception_block(128, 256, 2; stride = 2), + xception_block(256, 728, 2; stride = 2), + [xception_block(728, 728, 3) for _ in 1:8]..., + xception_block(728, 1024, 2; stride = 2, grow_at_start = false), + depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., + depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) 
+ head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(2048, nclasses)) + return Chain(body, head) +end + +struct Xception + layers::Any +end +@functor Xception + +""" + Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Xception model. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `Xception` does not currently support pretrained weights. +""" +function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "xception") + end + return Xception(layers) +end + +(m::Xception)(x) = m.layers(x) + +backbone(m::Xception) = m.layers[1] +classifier(m::Xception) = m.layers[2] diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl deleted file mode 100644 index 06274c94e..000000000 --- a/src/convnets/mobilenet.jl +++ /dev/null @@ -1,339 +0,0 @@ -# MobileNetv1 - -""" - mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - fcsize = 1024, - nclasses = 1000) - -Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `dw`: Set true to use a depthwise separable convolution or false for regular convolution - + `o`: The number of output feature maps - + `s`: The stride of the convolutional kernel - + `r`: The number of time this configuration block is repeated - - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. - - `fcsize`: The intermediate fully-connected size between the convolution and final layers - - `nclasses`: The number of output classes -""" -function mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - fcsize = 1024, - nclasses = 1000) - layers = [] - for (dw, outch, stride, nrepeats) in config - outch = Int(outch * width_mult) - for _ in 1:nrepeats - layer = dw ? - depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outch, activation; stride = stride, - pad = 1, - bias = false) - append!(layers, layer) - inchannels = outch - end - end - - return Chain(Chain(layers), - Chain(GlobalMeanPool(), - MLUtils.flatten, - Dense(inchannels, fcsize, activation), - Dense(fcsize, nclasses))) -end - -const mobilenetv1_configs = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), -] - -""" - MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv1 model with the baseline configuration -([reference](https://arxiv.org/abs/1704.04861v1)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. 
- -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv1`](#). -""" -struct MobileNetv1 - layers::Any -end - -function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("MobileNetv1")) - end - return MobileNetv1(layers) -end - -@functor MobileNetv1 - -(m::MobileNetv1)(x) = m.layers(x) - -backbone(m::MobileNetv1) = m.layers[1] -classifier(m::MobileNetv1) = m.layers[2] - -# MobileNetv2 - -""" - mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - -Create a MobileNetv2 model. -([reference](https://arxiv.org/abs/1801.04381)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer - + `c`: The number of output feature maps - + `n`: The number of times a block is repeated - + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: The number of output classes -""" -function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - # building first layer - inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) - layers = [] - append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) - # building inverted residual blocks - for (t, c, n, s, a) in configs - outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) - for i in 1:n - push!(layers, - invertedresidual(3, inplanes, inplanes * t, outplanes, a; - stride = i == 1 ? s : 1)) - inplanes = outplanes - end - end - # building last several layers - outplanes = (width_mult > 1) ? - _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : - max_width - return Chain(Chain(Chain(layers), - conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(outplanes, nclasses))) -end - -# Layer configurations for MobileNetv2 -const mobilenetv2_configs = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), -] - -# Model definition for MobileNetv2 -struct MobileNetv2 - layers::Any -end - -""" - MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv2 model with the specified configuration. -([reference](https://arxiv.org/abs/1801.04381)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. 
- - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv2`](#). -""" -function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv2")) - if pretrain - loadpretrain!(layers, string("MobileNetv2")) - end - return MobileNetv2(layers) -end - -@functor MobileNetv2 - -(m::MobileNetv2)(x) = m.layers(x) - -backbone(m::MobileNetv2) = m.layers[1] -classifier(m::MobileNetv2) = m.layers[2] - -# MobileNetv3 - -""" - mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - -Create a MobileNetv3 model. -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - - `configs`: a "list of tuples" configuration for each layer that details: - - + `k::Integer` - The size of the convolutional kernel - + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - + `t::Integer` - The number of output feature maps for a given block - + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - + `s::Integer` - The stride of the convolutional kernel - + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: the number of output classes -""" -function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - # building first layer - inplanes = _round_channels(16 * width_mult, 8) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, - bias = false)) - explanes = 0 - # building inverted residual blocks - for (k, t, c, r, a, s) in configs - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - invertedresidual(k, inplanes, explanes, outplanes, a; - stride = s, reduction = r)) - inplanes = outplanes - end - # building last several layers - output_channel = max_width - output_channel = width_mult > 1.0 ? 
_round_channels(output_channel * width_mult, 8) : - output_channel - classifier = Chain(Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses)) - return Chain(Chain(Chain(layers), - conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) -end - -# Configurations for small and large mode for MobileNetv3 -mobilenetv3_configs = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) - -# Model definition for MobileNetv3 -struct MobileNetv3 - layers::Any -end - -""" - MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv3 model with the specified configuration. -([reference](https://arxiv.org/abs/1905.02244)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -# Arguments - - - `mode`: :small or :large for the size of the model (see paper). - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of channels in the input. - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -See also [`Metalhead.mobilenetv3`](#). -""" -function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, - pretrain = false, nclasses = 1000) - @assert mode in [:large, :small] "`mode` has to be either :large or :small" - max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, - nclasses) - if pretrain - loadpretrain!(layers, string("MobileNetv3", mode)) - end - return MobileNetv3(layers) -end - -@functor MobileNetv3 - -(m::MobileNetv3)(x) = m.layers(x) - -backbone(m::MobileNetv3) = m.layers[1] -classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl new file mode 100644 index 000000000..4add739b4 --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -0,0 +1,102 @@ +""" + mobilenetv1(width_mult, config; + activation = relu, + inchannels = 3, + fcsize = 1024, + nclasses = 1000) + +Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
+ +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper) + + - `configs`: A "list of tuples" configuration for each layer that details: + + + `dw`: Set true to use a depthwise separable convolution or false for regular convolution + + `o`: The number of output feature maps + + `s`: The stride of the convolutional kernel + + `r`: The number of time this configuration block is repeated + - `activate`: The activation function to use throughout the network + - `inchannels`: The number of input channels. + - `fcsize`: The intermediate fully-connected size between the convolution and final layers + - `nclasses`: The number of output classes +""" +function mobilenetv1(width_mult, config; + activation = relu, + inchannels = 3, + fcsize = 1024, + nclasses = 1000) + layers = [] + for (dw, outch, stride, nrepeats) in config + outch = Int(outch * width_mult) + for _ in 1:nrepeats + layer = dw ? + depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; + stride = stride, pad = 1, bias = false) : + conv_norm((3, 3), inchannels, outch, activation; stride = stride, + pad = 1, + bias = false) + append!(layers, layer) + inchannels = outch + end + end + + return Chain(Chain(layers), + Chain(GlobalMeanPool(), + MLUtils.flatten, + Dense(inchannels, fcsize, activation), + Dense(fcsize, nclasses))) +end + +const mobilenetv1_configs = [ + # dw, c, s, r + (false, 32, 2, 1), + (true, 64, 1, 1), + (true, 128, 2, 1), + (true, 128, 1, 1), + (true, 256, 2, 1), + (true, 256, 1, 1), + (true, 512, 2, 1), + (true, 512, 1, 5), + (true, 1024, 2, 1), + (true, 1024, 1, 1), +] + +""" + MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv1 model with the baseline configuration +([reference](https://arxiv.org/abs/1704.04861v1)). +Set `pretrain` to `true` to load the pretrained weights for ImageNet. + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of input channels. + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `nclasses`: The number of output classes + +See also [`Metalhead.mobilenetv1`](#). +""" +struct MobileNetv1 + layers::Any +end +@functor MobileNetv1 + +function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, + nclasses = 1000) + layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv1")) + end + return MobileNetv1(layers) +end + +(m::MobileNetv1)(x) = m.layers(x) + +backbone(m::MobileNetv1) = m.layers[1] +classifier(m::MobileNetv1) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl new file mode 100644 index 000000000..21c017b42 --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -0,0 +1,97 @@ +""" + mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) + +Create a MobileNetv2 model. +([reference](https://arxiv.org/abs/1801.04381)). 
+ +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper) + + - `configs`: A "list of tuples" configuration for each layer that details: + + + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer + + `c`: The number of output feature maps + + `n`: The number of times a block is repeated + + `s`: The stride of the convolutional kernel + + `a`: The activation function used in the bottleneck layer + - `inchannels`: The number of input channels. + - `max_width`: The maximum number of feature maps in any layer of the network + - `nclasses`: The number of output classes +""" +function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) + # building first layer + inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) + layers = [] + append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) + # building inverted residual blocks + for (t, c, n, s, a) in configs + outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) + for i in 1:n + push!(layers, + invertedresidual(3, inplanes, inplanes * t, outplanes, a; + stride = i == 1 ? s : 1)) + inplanes = outplanes + end + end + # building last several layers + outplanes = (width_mult > 1) ? + _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : + max_width + return Chain(Chain(Chain(layers), + conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)...), + Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, + Dense(outplanes, nclasses))) +end + +# Layer configurations for MobileNetv2 +const mobilenetv2_configs = [ + # t, c, n, s, a + (1, 16, 1, 1, relu6), + (6, 24, 2, 2, relu6), + (6, 32, 3, 2, relu6), + (6, 64, 4, 2, relu6), + (6, 96, 3, 1, relu6), + (6, 160, 3, 2, relu6), + (6, 320, 1, 1, relu6), +] + +# Model definition for MobileNetv2 +struct MobileNetv2 + layers::Any +end +@functor MobileNetv2 + +""" + MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv2 model with the specified configuration. +([reference](https://arxiv.org/abs/1801.04381)). +Set `pretrain` to `true` to load the pretrained weights for ImageNet. + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of input channels. + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `nclasses`: The number of output classes + +See also [`Metalhead.mobilenetv2`](#). +""" +function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, + nclasses = 1000) + layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) + pretrain && loadpretrain!(layers, string("MobileNetv2")) + if pretrain + loadpretrain!(layers, string("MobileNetv2")) + end + return MobileNetv2(layers) +end + +(m::MobileNetv2)(x) = m.layers(x) + +backbone(m::MobileNetv2) = m.layers[1] +classifier(m::MobileNetv2) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl new file mode 100644 index 000000000..6bc444407 --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -0,0 +1,129 @@ +""" + mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) + +Create a MobileNetv3 model. +([reference](https://arxiv.org/abs/1905.02244)). 
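As a rough usage sketch of the wrappers these new mobilenet files define (illustrative only; it assumes the `MobileNetv1`, `MobileNetv2` and `MobileNetv3` constructors shown in this patch and a Float32 input in WHCN order):

    using Metalhead

    m1 = MobileNetv1()                 # defaults: width_mult = 1, 1000 classes
    m2 = MobileNetv2()
    m3 = MobileNetv3(:large)           # :small or :large
    x = rand(Float32, 224, 224, 3, 1)  # a single 224×224 RGB image
    size(m3(x))                        # expected to be (1000, 1)
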
+ +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + + - `configs`: a "list of tuples" configuration for each layer that details: + + + `k::Integer` - The size of the convolutional kernel + + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer + + `t::Integer` - The number of output feature maps for a given block + + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + + `s::Integer` - The stride of the convolutional kernel + + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) + - `inchannels`: The number of input channels. + - `max_width`: The maximum number of feature maps in any layer of the network + - `nclasses`: the number of output classes +""" +function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) + # building first layer + inplanes = _round_channels(16 * width_mult, 8) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, + bias = false)) + explanes = 0 + # building inverted residual blocks + for (k, t, c, r, a, s) in configs + # inverted residual layers + outplanes = _round_channels(c * width_mult, 8) + explanes = _round_channels(inplanes * t, 8) + push!(layers, + invertedresidual(k, inplanes, explanes, outplanes, a; + stride = s, reduction = r)) + inplanes = outplanes + end + # building last several layers + output_channel = max_width + output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : + output_channel + classifier = Chain(Dense(explanes, output_channel, hardswish), + Dropout(0.2), + Dense(output_channel, nclasses)) + return Chain(Chain(Chain(layers), + conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)...), + Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) +end + +# Configurations for small and large mode for MobileNetv3 +mobilenetv3_configs = Dict(:small => [ + # k, t, c, SE, a, s + (3, 1, 16, 4, relu, 2), + (3, 4.5, 24, nothing, relu, 2), + (3, 3.67, 24, nothing, relu, 1), + (5, 4, 40, 4, hardswish, 2), + (5, 6, 40, 4, hardswish, 1), + (5, 6, 40, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 2), + (5, 6, 96, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 1), + ], + :large => [ + # k, t, c, SE, a, s + (3, 1, 16, nothing, relu, 1), + (3, 4, 24, nothing, relu, 2), + (3, 3, 24, nothing, relu, 1), + (5, 3, 40, 4, relu, 2), + (5, 3, 40, 4, relu, 1), + (5, 3, 40, 4, relu, 1), + (3, 6, 80, nothing, hardswish, 2), + (3, 2.5, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 2), + (5, 6, 160, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 1), + ]) + +# Model definition for MobileNetv3 +struct MobileNetv3 + layers::Any +end +@functor MobileNetv3 + +""" + MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv3 model with the specified configuration. +([reference](https://arxiv.org/abs/1905.02244)). +Set `pretrain = true` to load the model with pre-trained weights for ImageNet. + +# Arguments + + - `mode`: :small or :large for the size of the model (see paper). 
+ - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of channels in the input. + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +See also [`Metalhead.mobilenetv3`](#). +""" +function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, + pretrain = false, nclasses = 1000) + @assert mode in [:large, :small] "`mode` has to be either :large or :small" + max_width = (mode == :large) ? 1280 : 1024 + layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, + nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv3", mode)) + end + return MobileNetv3(layers) +end + +(m::MobileNetv3)(x) = m.layers(x) + +backbone(m::MobileNetv3) = m.layers[1] +classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index ad60814d4..9af497844 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -32,9 +32,8 @@ function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activati norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = expansion_factor(basicblock) first_planes = planes ÷ reduction_factor - outplanes = planes * expansion + outplanes = planes * expansion_factor(basicblock) conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm, stride, pad = 1, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm, @@ -81,10 +80,9 @@ function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor - outplanes = planes * expansion + outplanes = planes * expansion_factor(bottleneck) conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm, @@ -322,8 +320,8 @@ function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; end # block-layer configurations for ResNet-like models -const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), - 34 => (basicblock, [3, 4, 6, 3]), - 50 => (bottleneck, [3, 4, 6, 3]), - 101 => (bottleneck, [3, 4, 23, 3]), - 152 => (bottleneck, [3, 8, 36, 3])) +const resnet_configs = Dict(18 => (basicblock, [2, 2, 2, 2]), + 34 => (basicblock, [3, 4, 6, 3]), + 50 => (bottleneck, [3, 4, 6, 3]), + 101 => (bottleneck, [3, 4, 23, 3]), + 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index ffbee32dc..97a9fb7a4 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -1,9 +1,3 @@ -const resnet_shortcuts = Dict(18 => [:A, :B, :B, :B], - 34 => [:A, :B, :B, :B], - 50 => [:B, :B, :B, :B], - 101 => [:B, :B, :B, :B], - 152 => [:B, :B, :B, :B]) - """ ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @@ -28,17 +22,16 @@ struct ResNet end @functor ResNet -(m::ResNet)(x) = m.layers(x) - function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 
101, 152] - "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]..., resnet_shortcuts[depth]; inchannels, nclasses) + _checkconfig(depth, keys(resnet_configs)) + layers = resnet(resnet_configs[depth]..., resnet_shortcuts[depth]; inchannels, nclasses) if pretrain loadpretrain!(layers, string("ResNet", depth)) end return ResNet(layers) end +(m::ResNet)(x) = m.layers(x) + backbone(m::ResNet) = m.layers[1] classifier(m::ResNet) = m.layers[2] diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index cee2e4757..47e81d44d 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -29,9 +29,8 @@ end function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width) + _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 26bfbb6c1..ae25a73e8 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -6,7 +6,7 @@ Creates a SEResNet model with the specified depth. # Arguments - - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: the number of input channels. - `nclasses`: the number of output classes @@ -25,9 +25,8 @@ end (m::SEResNet)(x) = m.layers(x) function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, + _checkconfig(depth, keys(resnet_configs)) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses, attn_fn = planes -> squeeze_excite(planes)) if pretrain loadpretrain!(layers, string("SEResNet", depth)) @@ -69,9 +68,8 @@ end function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. 
Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width, + _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width, attn_fn = planes -> squeeze_excite(planes)) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index df458f9ff..abcdd63f8 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -15,11 +15,7 @@ function fire(inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes) branch_1 = Conv((1, 1), inplanes => squeeze_planes, relu) branch_2 = Conv((1, 1), squeeze_planes => expand1x1_planes, relu) branch_3 = Conv((3, 3), squeeze_planes => expand3x3_planes, relu; pad = 1) - - return Chain(branch_1, - Parallel(cat_channels, - branch_2, - branch_3)) + return Chain(branch_1, Parallel(cat_channels, branch_2, branch_3)) end """ @@ -29,24 +25,22 @@ Create a SqueezeNet ([reference](https://arxiv.org/abs/1602.07360v4)). """ function squeezenet() - layers = Chain(Chain(Conv((3, 3), 3 => 64, relu; stride = 2), - MaxPool((3, 3); stride = 2), - fire(64, 16, 64, 64), - fire(128, 16, 64, 64), - MaxPool((3, 3); stride = 2), - fire(128, 32, 128, 128), - fire(256, 32, 128, 128), - MaxPool((3, 3); stride = 2), - fire(256, 48, 192, 192), - fire(384, 48, 192, 192), - fire(384, 64, 256, 256), - fire(512, 64, 256, 256), - Dropout(0.5), - Conv((1, 1), 512 => 1000, relu)), - AdaptiveMeanPool((1, 1)), - MLUtils.flatten) - - return layers + return Chain(Chain(Conv((3, 3), 3 => 64, relu; stride = 2), + MaxPool((3, 3); stride = 2), + fire(64, 16, 64, 64), + fire(128, 16, 64, 64), + MaxPool((3, 3); stride = 2), + fire(128, 32, 128, 128), + fire(256, 32, 128, 128), + MaxPool((3, 3); stride = 2), + fire(256, 48, 192, 192), + fire(384, 48, 192, 192), + fire(384, 64, 256, 256), + fire(512, 64, 256, 256), + Dropout(0.5), + Conv((1, 1), 512 => 1000, relu)), + AdaptiveMeanPool((1, 1)), + MLUtils.flatten) end """ @@ -65,6 +59,7 @@ See also [`squeezenet`](#). struct SqueezeNet layers::Any end +@functor SqueezeNet function SqueezeNet(; pretrain = false) layers = squeezenet() @@ -74,8 +69,6 @@ function SqueezeNet(; pretrain = false) return SqueezeNet(layers) end -@functor SqueezeNet - (m::SqueezeNet)(x) = m.layers(x) backbone(m::SqueezeNet) = m.layers[1] diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index c8a2e6344..3a1a8ac10 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -99,15 +99,15 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr return Chain(Chain(conv), class) end -const vgg_conv_config = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], - :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], - :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], - :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) +const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], + :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], + :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], + :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) -const vgg_config = Dict(11 => :A, - 13 => :B, - 16 => :D, - 19 => :E) +const vgg_configs = Dict(11 => :A, + 13 => :B, + 16 => :D, + 19 => :E) struct VGG layers::Any @@ -153,8 +153,8 @@ See also [`VGG`](#). 
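Purely as an illustration of the renamed `vgg_configs` lookup used below (assuming the exported `VGG` constructor; the depth-based method hard-codes a 224×224 input size):

    using Metalhead

    m = VGG(16; batchnorm = true)      # valid depths: 11, 13, 16, 19
    x = rand(Float32, 224, 224, 3, 1)
    size(m(x))                         # expected to be (1000, 1)
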
- `pretrain`: set to `true` to load pre-trained model weights for ImageNet """ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000) - @assert depth in keys(vgg_config) "depth must be from one in $(sort(collect(keys(vgg_config))))" - model = VGG((224, 224); config = vgg_conv_config[vgg_config[depth]], + _checkconfig(depth, keys(vgg_configs)) + model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]], inchannels = 3, batchnorm = batchnorm, nclasses = nclasses, diff --git a/src/mixers/core.jl b/src/mixers/core.jl new file mode 100644 index 000000000..6a55f048e --- /dev/null +++ b/src/mixers/core.jl @@ -0,0 +1,43 @@ +""" + mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., + depth = 12, nclasses = 1000, kwargs...) + +Creates a model with the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)). + +# Arguments + + - `block`: the type of mixer block to use in the model - architecture dependent + (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) + - `imsize`: the size of the input image + - `inchannels`: the number of input channels + - `norm_layer`: the normalization layer to use in the model + - `patch_size`: the size of the patches + - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) + - `drop_path_rate`: Stochastic depth rate + - `depth`: the number of blocks in the model + - `nclasses`: number of output classes + - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if + not specified. +""" +function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, + norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), + embedplanes = 512, drop_path_rate = 0.0, + depth = 12, nclasses = 1000, kwargs...) + npatches = prod(imsize .÷ patch_size) + dp_rates = linear_scheduler(drop_path_rate; depth) + layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), + Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], + kwargs...) + for i in 1:depth])) + classification_head = Chain(norm_layer(embedplanes), seconddimmean, + Dense(embedplanes, nclasses)) + return Chain(layers, classification_head) +end + +# Configurations for MLPMixer models +mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), + :base => Dict(:depth => 12, :planes => 768), + :large => Dict(:depth => 24, :planes => 1024), + :huge => Dict(:depth => 32, :planes => 1280)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl new file mode 100644 index 000000000..4e681e9b4 --- /dev/null +++ b/src/mixers/gmlp.jl @@ -0,0 +1,110 @@ +""" + SpatialGatingUnit(norm, proj) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `norm`: the normalisation layer to use + - `proj`: the projection layer to use +""" +struct SpatialGatingUnit{T, F} + norm::T + proj::F +end +@functor SpatialGatingUnit + +""" + SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) + +Creates a spatial gating unit as described in the gMLP paper. 
+([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `norm_layer`: the normalisation layer to use +""" +function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) + gateplanes = planes ÷ 2 + norm = norm_layer(gateplanes) + proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) + return SpatialGatingUnit(norm, proj) +end + +function (m::SpatialGatingUnit)(x) + u, v = chunk(x, 2; dims = 1) + v = m.norm(v) + v = m.proj(permutedims(v, (2, 1, 3))) + return u .* permutedims(v, (2, 1, 3)) +end + +""" + spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, + norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, + activation = gelu) + +Creates a feedforward block based on the gMLP model architecture described in the paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block + - `norm_layer`: the normalisation layer to use + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks +""" +function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, + mlp_layer = gated_mlp_block, dropout_rate = 0.0, + drop_path_rate = 0.0, + activation = gelu) + channelplanes = Int(mlp_ratio * planes) + sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) + return SkipConnection(Chain(norm_layer(planes), + mlp_layer(sgu, planes, channelplanes; activation, + dropout_rate), + DropPath(drop_path_rate)), +) +end + +struct gMLP + layers::Any +end +@functor gMLP + +""" + gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) + +Creates a model with the gMLP architecture. +([reference](https://arxiv.org/abs/2105.08050)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(mixer_configs)) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, + patch_size, embedplanes, drop_path_rate, depth, nclasses) + return gMLP(layers) +end + +(m::gMLP)(x) = m.layers(x) + +backbone(m::gMLP) = m.layers[1] +classifier(m::gMLP) = m.layers[2] diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl new file mode 100644 index 000000000..e3da17a23 --- /dev/null +++ b/src/mixers/mlpmixer.jl @@ -0,0 +1,69 @@ +""" + mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout_rate = 0., drop_path_rate = 0., activation = gelu) + +Creates a feedforward block for the MLPMixer architecture. 
+([reference](https://arxiv.org/pdf/2105.01601)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP + and/or the channel mixing MLP as a ratio to the number of planes in the block. + - `mlp_layer`: the MLP layer to use in the block + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks +""" +function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] + return Chain(SkipConnection(Chain(LayerNorm(planes), + swapdims((2, 1, 3)), + mlp_layer(npatches, tokenplanes; activation, + dropout_rate), + swapdims((2, 1, 3)), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(LayerNorm(planes), + mlp_layer(planes, channelplanes; activation, + dropout_rate), + DropPath(drop_path_rate)), +)) +end + +struct MLPMixer + layers::Any +end +@functor MLPMixer + +""" + MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) + +Creates a model with the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(mixer_configs)) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, + nclasses) + return MLPMixer(layers) +end + +(m::MLPMixer)(x) = m.layers(x) + +backbone(m::MLPMixer) = m.layers[1] +classifier(m::MLPMixer) = m.layers[2] diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl new file mode 100644 index 000000000..38163702c --- /dev/null +++ b/src/mixers/resmlp.jl @@ -0,0 +1,72 @@ +""" + resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, + activation = gelu, λ = 1e-4) + +Creates a block for the ResMixer architecture. +([reference](https://arxiv.org/abs/2105.03404)). 
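A hedged usage sketch for the `MLPMixer` wrapper defined just above, mirroring the calls in `test/mixers.jl` later in this patch; with the default `imsize = (224, 224)` and `patch_size = (16, 16)` the underlying `mlpmixer` builder operates on `prod((224, 224) .÷ (16, 16)) == 196` patches:

    using Metalhead

    m = MLPMixer(:small; drop_path_rate = 0.5)   # sizes: :small, :base, :large, :huge
    x = rand(Float32, 224, 224, 3, 1)
    size(m(x))                                   # (1000, 1) in the tests
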
+ +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block + - `mlp_layer`: the MLP block to use + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks + - `λ`: initialisation constant for the LayerScale +""" +function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, + λ = 1e-4) + return Chain(SkipConnection(Chain(Flux.Scale(planes), + swapdims((2, 1, 3)), + Dense(npatches, npatches), + swapdims((2, 1, 3)), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(Flux.Scale(planes), + mlp_layer(planes, Int(mlp_ratio * planes); + dropout_rate, + activation), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +)) +end + +struct ResMLP + layers::Any +end +@functor ResMLP + +""" + ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + +Creates a model with the ResMLP architecture. +([reference](https://arxiv.org/abs/2105.03404)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(mixer_configs)) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, + drop_path_rate, depth, nclasses) + return ResMLP(layers) +end + +(m::ResMLP)(x) = m.layers(x) + +backbone(m::ResMLP) = m.layers[1] +classifier(m::ResMLP) = m.layers[2] diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl deleted file mode 100644 index ac66be28b..000000000 --- a/src/other/mlpmixer.jl +++ /dev/null @@ -1,302 +0,0 @@ -""" - mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout_rate = 0., drop_path_rate = 0., activation = gelu) - -Creates a feedforward block for the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP - and/or the channel mixing MLP as a ratio to the number of planes in the block. 
- - `mlp_layer`: the MLP layer to use in the block - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks -""" -function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) - tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] - return Chain(SkipConnection(Chain(LayerNorm(planes), - swapdims((2, 1, 3)), - mlp_layer(npatches, tokenplanes; activation, - dropout_rate), - swapdims((2, 1, 3)), - DropPath(drop_path_rate)), +), - SkipConnection(Chain(LayerNorm(planes), - mlp_layer(planes, channelplanes; activation, - dropout_rate), - DropPath(drop_path_rate)), +)) -end - -""" - mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., - depth = 12, nclasses = 1000, kwargs...) - -Creates a model with the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)). - -# Arguments - - - `block`: the type of mixer block to use in the model - architecture dependent - (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) - - `imsize`: the size of the input image - - `inchannels`: the number of input channels - - `norm_layer`: the normalization layer to use in the model - - `patch_size`: the size of the patches - - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) - - `drop_path_rate`: Stochastic depth rate - - `depth`: the number of blocks in the model - - `nclasses`: number of output classes - - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if - not specified. -""" -function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, - norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), - embedplanes = 512, drop_path_rate = 0.0, - depth = 12, nclasses = 1000, kwargs...) - npatches = prod(imsize .÷ patch_size) - dp_rates = linear_scheduler(drop_path_rate; depth) - layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), - Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], - kwargs...) - for i in 1:depth])) - - classification_head = Chain(norm_layer(embedplanes), seconddimmean, - Dense(embedplanes, nclasses)) - return Chain(layers, classification_head) -end - -# Configurations for MLPMixer models -mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), - :base => Dict(:depth => 12, :planes => 768), - :large => Dict(:depth => 24, :planes => 1024), - :huge => Dict(:depth => 32, :planes => 1280)) - -struct MLPMixer - layers::Any -end - -""" - MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) - -Creates a model with the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). 
-""" -function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, - nclasses) - return MLPMixer(layers) -end - -@functor MLPMixer - -(m::MLPMixer)(x) = m.layers(x) - -backbone(m::MLPMixer) = m.layers[1] -classifier(m::MLPMixer) = m.layers[2] - -""" - resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, - activation = gelu, λ = 1e-4) - -Creates a block for the ResMixer architecture. -([reference](https://arxiv.org/abs/2105.03404)). - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number - of planes in the block - - `mlp_layer`: the MLP block to use - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks - - `λ`: initialisation constant for the LayerScale -""" -function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, - λ = 1e-4) - return Chain(SkipConnection(Chain(Flux.Scale(planes), - swapdims((2, 1, 3)), - Dense(npatches, npatches), - swapdims((2, 1, 3)), - LayerScale(planes, λ), - DropPath(drop_path_rate)), +), - SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); - dropout_rate, - activation), - LayerScale(planes, λ), - DropPath(drop_path_rate)), +)) -end - -struct ResMLP - layers::Any -end - -""" - ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), - drop_path_rate = 0., nclasses = 1000) - -Creates a model with the ResMLP architecture. -([reference](https://arxiv.org/abs/2105.03404)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, - drop_path_rate, depth, nclasses) - return ResMLP(layers) -end - -@functor ResMLP - -(m::ResMLP)(x) = m.layers(x) - -backbone(m::ResMLP) = m.layers[1] -classifier(m::ResMLP) = m.layers[2] - -""" - SpatialGatingUnit(norm, proj) - -Creates a spatial gating unit as described in the gMLP paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `norm`: the normalisation layer to use - - `proj`: the projection layer to use -""" -struct SpatialGatingUnit{T, F} - norm::T - proj::F -end - -""" - SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - -Creates a spatial gating unit as described in the gMLP paper. 
-([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `norm_layer`: the normalisation layer to use -""" -function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - gateplanes = planes ÷ 2 - norm = norm_layer(gateplanes) - proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) - return SpatialGatingUnit(norm, proj) -end - -@functor SpatialGatingUnit - -function (m::SpatialGatingUnit)(x) - u, v = chunk(x, 2; dims = 1) - v = m.norm(v) - v = m.proj(permutedims(v, (2, 1, 3))) - return u .* permutedims(v, (2, 1, 3)) -end - -""" - spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, - activation = gelu) - -Creates a feedforward block based on the gMLP model architecture described in the paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number - of planes in the block - - `norm_layer`: the normalisation layer to use - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks -""" -function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, dropout_rate = 0.0, - drop_path_rate = 0.0, - activation = gelu) - channelplanes = Int(mlp_ratio * planes) - sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) - return SkipConnection(Chain(norm_layer(planes), - mlp_layer(sgu, planes, channelplanes; activation, - dropout_rate), - DropPath(drop_path_rate)), +) -end - -struct gMLP - layers::Any -end - -""" - gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) - -Creates a model with the gMLP architecture. -([reference](https://arxiv.org/abs/2105.08050)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, - patch_size, embedplanes, drop_path_rate, depth, nclasses) - return gMLP(layers) -end - -@functor gMLP - -(m::gMLP)(x) = m.layers(x) - -backbone(m::gMLP) = m.layers[1] -classifier(m::gMLP) = m.layers[2] diff --git a/src/utilities.jl b/src/utilities.jl index b420efd0b..60284a343 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -68,5 +68,11 @@ end Returns the dropout rates for a given depth using the linear scaling rule. 
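Concretely, `linear_scheduler` returns an evenly spaced ramp from `start_value` to `drop_rate` (the change below only drops the hard-coded `Float32` element type); an illustrative call:

    Metalhead.linear_scheduler(0.2; depth = 5)   # LinRange(0.0, 0.2, 5): 0.0, 0.05, 0.1, 0.15, 0.2
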
""" function linear_scheduler(drop_rate = 0.0; depth, start_value = 0.0) - return LinRange{Float32}(start_value, drop_rate, depth) + return LinRange(start_value, drop_rate, depth) +end + +# Utility function for depth and configuration checks in models +function _checkconfig(config, configs) + @assert config in configs + return "Invalid configuration. Must be one of $(sort(collect(configs)))." end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 856b64697..93eba09ee 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -94,10 +94,11 @@ See also [`Metalhead.vit`](#). struct ViT layers::Any end +@functor ViT function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) - @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))" + _checkconfig(mode, keys(vit_configs)) kwargs = vit_configs[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) return ViT(layers) @@ -107,5 +108,3 @@ end backbone(m::ViT) = m.layers[1] classifier(m::ViT) = m.layers[2] - -@functor ViT diff --git a/test/convnets.jl b/test/convnets.jl index 601f51421..860ac422b 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -46,7 +46,7 @@ end (dropout_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), ] @testset for drop_rates in drop_list - m = Metalhead.resnet(block_fn, layers; drop_rates) + m = Metalhead.resnet(block_fn, layers; drop_rates...) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() @@ -75,7 +75,7 @@ end end @testset "SEResNet" begin - @testset for depth in [50, 101, 152] + @testset for depth in [18, 34, 50, 101, 152] m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) if (SEResNet, depth) in PRETRAINED_MODELS diff --git a/test/other.jl b/test/mixers.jl similarity index 76% rename from test/other.jl rename to test/mixers.jl index df97d4f5f..5e2e8cf27 100644 --- a/test/other.jl +++ b/test/mixers.jl @@ -1,5 +1,5 @@ @testset "MLPMixer" begin - @testset for mode in [:small, :base] # :large, # :huge] + @testset for mode in [:small, :base, :large, :huge] @testset for drop_path_rate in [0.0, 0.5] m = MLPMixer(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @@ -10,7 +10,7 @@ end @testset "ResMLP" begin - @testset for mode in [:small, :base] # :large, # :huge] + @testset for mode in [:small, :base, :large, :huge] @testset for drop_path_rate in [0.0, 0.5] m = ResMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @@ -21,7 +21,7 @@ end end @testset "gMLP" begin - @testset for mode in [:small, :base] # :large, # :huge] + @testset for mode in [:small, :base, :large, :huge] @testset for drop_path_rate in [0.0, 0.5] m = gMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) diff --git a/test/runtests.jl b/test/runtests.jl index 1a8c77f25..eaef97472 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,9 +61,9 @@ end GC.safepoint() GC.gc() -# Other tests -@testset verbose = true "Other" begin - include("other.jl") +# Mixer tests +@testset verbose = true "Mixers" begin + include("mixers.jl") end GC.safepoint() @@ -71,5 +71,5 @@ GC.gc() # ViT tests @testset verbose = true "ViTs" begin - include("vit-based.jl") + include("vits.jl") end diff --git a/test/vit-based.jl b/test/vits.jl similarity index 65% rename from test/vit-based.jl rename to test/vits.jl index e889b07be..13733ddec 100644 --- a/test/vit-based.jl +++ b/test/vits.jl @@ -1,5 +1,5 @@ @testset "ViT" begin - for mode in [:small, :base, :large] # :tiny, #,:huge, :giant, 
:gigantic] + for mode in [:tiny, :small, :base, :large, :huge] #:giant, #:gigantic m = ViT(mode) @test size(m(x_256)) == (1000, 1) @test gradtest(m, x_256) From ce1da459ab4c49195966c4d32a7807d952b86bf1 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 23 Jul 2022 11:16:37 +0530 Subject: [PATCH 52/64] Reorganisation And some more formatting --- .github/workflows/CI.yml | 3 +- src/Metalhead.jl | 34 +- src/convnets/alexnet.jl | 4 +- src/convnets/convmixer.jl | 26 +- src/convnets/convnext.jl | 25 +- src/convnets/densenet.jl | 14 +- src/convnets/efficientnet.jl | 33 +- src/convnets/inception.jl | 605 -------------------- src/convnets/{ => inception}/googlenet.jl | 6 +- src/convnets/inception/inceptionresnetv2.jl | 133 +++++ src/convnets/inception/inceptionv3.jl | 196 +++++++ src/convnets/inception/inceptionv4.jl | 158 +++++ src/convnets/inception/xception.jl | 106 ++++ src/convnets/mobilenet.jl | 339 ----------- src/convnets/mobilenet/mobilenetv1.jl | 102 ++++ src/convnets/mobilenet/mobilenetv2.jl | 97 ++++ src/convnets/mobilenet/mobilenetv3.jl | 129 +++++ src/convnets/resnets/core.jl | 16 +- src/convnets/resnets/resnet.jl | 15 +- src/convnets/resnets/resnext.jl | 5 +- src/convnets/resnets/seresnet.jl | 12 +- src/convnets/squeezenet.jl | 43 +- src/convnets/vgg.jl | 20 +- src/layers/drop.jl | 17 +- src/mixers/core.jl | 43 ++ src/mixers/gmlp.jl | 110 ++++ src/mixers/mlpmixer.jl | 69 +++ src/mixers/resmlp.jl | 72 +++ src/other/mlpmixer.jl | 302 ---------- src/utilities.jl | 12 +- src/vit-based/vit.jl | 5 +- test/convnets.jl | 4 +- test/{other.jl => mixers.jl} | 6 +- test/runtests.jl | 8 +- test/{vit-based.jl => vits.jl} | 2 +- 35 files changed, 1354 insertions(+), 1417 deletions(-) delete mode 100644 src/convnets/inception.jl rename src/convnets/{ => inception}/googlenet.jl (99%) create mode 100644 src/convnets/inception/inceptionresnetv2.jl create mode 100644 src/convnets/inception/inceptionv3.jl create mode 100644 src/convnets/inception/inceptionv4.jl create mode 100644 src/convnets/inception/xception.jl delete mode 100644 src/convnets/mobilenet.jl create mode 100644 src/convnets/mobilenet/mobilenetv1.jl create mode 100644 src/convnets/mobilenet/mobilenetv2.jl create mode 100644 src/convnets/mobilenet/mobilenetv3.jl create mode 100644 src/mixers/core.jl create mode 100644 src/mixers/gmlp.jl create mode 100644 src/mixers/mlpmixer.jl create mode 100644 src/mixers/resmlp.jl delete mode 100644 src/other/mlpmixer.jl rename test/{other.jl => mixers.jl} (76%) rename test/{vit-based.jl => vits.jl} (65%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e148bdff2..32882857b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,9 +34,8 @@ jobs: - '["InceptionResNetv2", "Xception"]' - '"DenseNet"' - '["ConvNeXt", "ConvMixer"]' - # - '"ConvMixer"' - '"ViT"' - - '"Other"' + - '"Mixers"' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 67731825e..3c8469dd2 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -21,26 +21,38 @@ using .Layers # CNN models include("convnets/alexnet.jl") include("convnets/vgg.jl") -include("convnets/inception.jl") -include("convnets/googlenet.jl") -include("convnets/densenet.jl") -include("convnets/squeezenet.jl") -include("convnets/mobilenet.jl") -include("convnets/efficientnet.jl") -include("convnets/convnext.jl") -include("convnets/convmixer.jl") ## ResNets include("convnets/resnets/core.jl") 
include("convnets/resnets/resnet.jl") include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") +## Inceptions +include("convnets/inception/googlenet.jl") +include("convnets/inception/inceptionv3.jl") +include("convnets/inception/inceptionv4.jl") +include("convnets/inception/inceptionresnetv2.jl") +include("convnets/inception/xception.jl") +## MobileNets +include("convnets/mobilenet/mobilenetv1.jl") +include("convnets/mobilenet/mobilenetv2.jl") +include("convnets/mobilenet/mobilenetv3.jl") +## Others +include("convnets/densenet.jl") +include("convnets/squeezenet.jl") +include("convnets/efficientnet.jl") +include("convnets/convnext.jl") +include("convnets/convmixer.jl") -# Other models -include("other/mlpmixer.jl") +# Mixers +include("mixers/core.jl") +include("mixers/mlpmixer.jl") +include("mixers/resmlp.jl") +include("mixers/gmlp.jl") -# ViT-based models +# ViTs include("vit-based/vit.jl") +# Load pretrained weights include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 87f2c288e..8ff65ffef 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -24,7 +24,6 @@ function alexnet(; nclasses = 1000) Dropout(0.5), Dense(4096, 4096, relu), Dense(4096, nclasses))) - return layers end @@ -46,6 +45,7 @@ See also [`alexnet`](#). struct AlexNet layers::Any end +@functor AlexNet function AlexNet(; pretrain = false, nclasses = 1000) layers = alexnet(; nclasses = nclasses) @@ -55,8 +55,6 @@ function AlexNet(; pretrain = false, nclasses = 1000) return AlexNet(layers) end -@functor AlexNet - (m::AlexNet)(x) = m.layers(x) backbone(m::AlexNet) = m.layers[1] diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index f3566e278..6547ba4fb 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -28,13 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), return Chain(Chain(stem..., Chain(blocks)), head) end -convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9), - :patch_size => (7, 7)), - :small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7), - :patch_size => (7, 7)), - :large => Dict(:planes => 1024, :depth => 20, +convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9), - :patch_size => (7, 7))) + :patch_size => (7, 7)), + :small => Dict(:planes => 768, :depth => 32, + :kernel_size => (7, 7), + :patch_size => (7, 7)), + :large => Dict(:planes => 1024, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7))) """ ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) @@ -52,19 +54,19 @@ Creates a ConvMixer model. 
struct ConvMixer layers::Any end +@functor ConvMixer function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) - planes = convmixer_config[mode][:planes] - depth = convmixer_config[mode][:depth] - kernel_size = convmixer_config[mode][:kernel_size] - patch_size = convmixer_config[mode][:patch_size] + _checkconfig(mode, keys(convmixer_configs)) + planes = convmixer_configs[mode][:planes] + depth = convmixer_configs[mode][:depth] + kernel_size = convmixer_configs[mode][:kernel_size] + patch_size = convmixer_configs[mode][:patch_size] layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses) return ConvMixer(layers) end -@functor ConvMixer - (m::ConvMixer)(x) = m.layers(x) backbone(m::ConvMixer) = m.layers[1] diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 383e5d128..052192fec 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -66,20 +66,16 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 end # Configurations for ConvNeXt models -convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3], - :planes => [96, 192, 384, 768]), - :small => Dict(:depths => [3, 3, 27, 3], - :planes => [96, 192, 384, 768]), - :base => Dict(:depths => [3, 3, 27, 3], - :planes => [128, 256, 512, 1024]), - :large => Dict(:depths => [3, 3, 27, 3], - :planes => [192, 384, 768, 1536]), - :xlarge => Dict(:depths => [3, 3, 27, 3], - :planes => [256, 512, 1024, 2048])) +convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), + :small => ([3, 3, 27, 3], [96, 192, 384, 768]), + :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), + :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), + :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) struct ConvNeXt layers::Any end +@functor ConvNeXt """ ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000) @@ -98,17 +94,12 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert mode in keys(convnext_configs) - "`size` must be one of $(collect(keys(convnext_configs)))" - depths = convnext_configs[mode][:depths] - planes = convnext_configs[mode][:planes] - layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses) + _checkconfig(mode, keys(convnext_configs)) + layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses) return ConvNeXt(layers) end (m::ConvNeXt)(x) = m.layers(x) -@functor ConvNeXt - backbone(m::ConvNeXt) = m.layers[1] classifier(m::ConvNeXt) = m.layers[2] diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 878e3fc8d..0c5bd6ad6 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -140,14 +140,14 @@ end backbone(m::DenseNet) = m.layers[1] classifier(m::DenseNet) = m.layers[2] -const densenet_config = Dict(121 => (6, 12, 24, 16), - 161 => (6, 12, 36, 24), - 169 => (6, 12, 32, 32), - 201 => (6, 12, 48, 32)) +const densenet_configs = Dict(121 => (6, 12, 24, 16), + 161 => (6, 12, 36, 24), + 169 => (6, 12, 32, 32), + 201 => (6, 12, 48, 32)) """ DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - DenseNet(transition_config::NTuple{N,Integer}) + DenseNet(transition_configs::NTuple{N,Integer}) Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201) ([reference](https://arxiv.org/abs/1608.06993)). 
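For illustration, the depth-keyed convenience constructor updated below would be used roughly like this (assuming the exported `DenseNet` and a 224×224 input):

    using Metalhead

    m = DenseNet(121)                  # any of 121, 161, 169, 201
    x = rand(Float32, 224, 224, 3, 1)
    size(m(x))                         # expected to be (1000, 1)
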
@@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. See also [`Metalhead.densenet`](#). """ function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))." - model = DenseNet(densenet_config[config]; nclasses = nclasses) + _checkconfig(config, keys(densenet_configs)) + model = DenseNet(densenet_configs[config]; nclasses = nclasses) if pretrain loadpretrain!(model, string("DenseNet", config)) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index aeacc8092..122fd512a 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -1,5 +1,5 @@ """ - efficientnet(scalings, block_config; + efficientnet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). @@ -8,7 +8,7 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `scalings`: global width and depth scaling (given as a tuple) - - `block_config`: configuration for each inverted residual block, + - `block_configs`: configuration for each inverted residual block, given as a vector of tuples with elements: + `n`: number of block repetitions (will be scaled by global depth scaling) @@ -22,22 +22,19 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `max_width`: maximum number of output channels before the fully connected classification blocks """ -function efficientnet(scalings, block_config; +function efficientnet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - out_channels = _round_channels(scalew(32), 8) stem = conv_norm((3, 3), inchannels, out_channels, swish; bias = false, stride = 2, pad = SamePad()) - blocks = [] - for (n, k, s, e, i, o) in block_config + for (n, k, s, e, i, o) in block_configs in_channels = _round_channels(scalew(i), 8) out_channels = _round_channels(scalew(o), 8) repeats = scaled(n) - push!(blocks, invertedresidual(k, in_channels, in_channels * e, out_channels, swish; stride = s, reduction = 4)) @@ -48,13 +45,10 @@ function efficientnet(scalings, block_config; end end blocks = Chain(blocks...) - head_out_channels = _round_channels(max_width, 8) head = conv_norm((1, 1), out_channels, head_out_channels, swish; bias = false, pad = SamePad()) - top = Dense(head_out_channels, nclasses) - return Chain(Chain([stem..., blocks, head...]), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top)) end @@ -92,9 +86,10 @@ const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)), struct EfficientNet layers::Any end +@functor EfficientNet """ - EfficientNet(scalings, block_config; + EfficientNet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). @@ -104,7 +99,7 @@ See also [`efficientnet`](#). - `scalings`: global width and depth scaling (given as a tuple) - - `block_config`: configuration for each inverted residual block, + - `block_configs`: configuration for each inverted residual block, given as a vector of tuples with elements: + `n`: number of block repetitions (will be scaled by global depth scaling) @@ -118,17 +113,12 @@ See also [`efficientnet`](#). 
- `max_width`: maximum number of output channels before the fully connected classification blocks """ -function EfficientNet(scalings, block_config; +function EfficientNet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) - layers = efficientnet(scalings, block_config; - inchannels = inchannels, - nclasses = nclasses, - max_width = max_width) + layers = efficientnet(scalings, block_configs; inchannels, nclasses, max_width) return EfficientNet(layers) end -@functor EfficientNet - (m::EfficientNet)(x) = m.layers(x) backbone(m::EfficientNet) = m.layers[1] @@ -147,11 +137,8 @@ See also [`efficientnet`](#). - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) - @assert name in keys(efficientnet_global_configs) - "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" - + _checkconfig(name, keys(efficientnet_global_configs)) model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) pretrain && loadpretrain!(model, string("efficientnet-", name)) - return model end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl deleted file mode 100644 index 4fd43f26c..000000000 --- a/src/convnets/inception.jl +++ /dev/null @@ -1,605 +0,0 @@ -## Inceptionv3 - -""" - inceptionv3_a(inplanes, pool_proj) - -Create an Inception-v3 style-A module -(ref: Fig. 5 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps - - `pool_proj`: the number of output feature maps for the pooling projection -""" -function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, pool_proj)...) - return Parallel(cat_channels, - branch1x1, branch5x5, branch3x3, branch_pool) -end - -""" - inceptionv3_b(inplanes) - -Create an Inception-v3 style-B module -(ref: Fig. 10 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; stride = 2)...) - branch_pool = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, - branch3x3_1, branch3x3_2, branch_pool) -end - -""" - inceptionv3_c(inplanes, inner_planes, n = 7) - -Create an Inception-v3 style-C module -(ref: Fig. 6 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps - - `inner_planes`: the number of output feature maps within each branch - - `n`: the "grid size" (kernel size) for the convolution layers -""" -function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) 
- branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) - return Parallel(cat_channels, - branch1x1, branch7x7_1, branch7x7_2, branch_pool) -end - -""" - inceptionv3_d(inplanes) - -Create an Inception-v3 style-D module -(ref: [pytorch](https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/inception.py#L322)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., - conv_norm((1, 7), 192, 192; pad = (0, 3))..., - conv_norm((7, 1), 192, 192; pad = (3, 0))..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch_pool = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, - branch3x3, branch7x7x3, branch_pool) -end - -""" - inceptionv3_e(inplanes) - -Create an Inception-v3 style-E module -(ref: Fig. 7 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., - conv_norm((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), inplanes, 192)...) - return Parallel(cat_channels, - branch1x1, - Chain(branch3x3_1, - Parallel(cat_channels, - branch3x3_1a, branch3x3_1b)), - Chain(branch3x3_2, - Parallel(cat_channels, - branch3x3_2a, branch3x3_2b)), - branch_pool) -end - -""" - inceptionv3(; nclasses = 1000) - -Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `nclasses`: the number of output classes -""" -function inceptionv3(; nclasses = 1000) - layer = Chain(Chain(conv_norm((3, 3), 3, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - MaxPool((3, 3); stride = 2), - conv_norm((1, 1), 64, 80)..., - conv_norm((3, 3), 80, 192)..., - MaxPool((3, 3); stride = 2), - inceptionv3_a(192, 32), - inceptionv3_a(256, 64), - inceptionv3_a(288, 64), - inceptionv3_b(288), - inceptionv3_c(768, 128), - inceptionv3_c(768, 160), - inceptionv3_c(768, 160), - inceptionv3_c(768, 192), - inceptionv3_d(768), - inceptionv3_e(1280), - inceptionv3_e(2048)), - Chain(AdaptiveMeanPool((1, 1)), - Dropout(0.2), - MLUtils.flatten, - Dense(2048, nclasses))) - return layer -end - -""" - Inceptionv3(; pretrain = false, nclasses = 1000) - -Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). -See also [`inceptionv3`](#). - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -!!! 
warning - - `Inceptionv3` does not currently support pretrained weights. -""" -struct Inceptionv3 - layers::Any -end - -function Inceptionv3(; pretrain = false, nclasses = 1000) - layers = inceptionv3(; nclasses = nclasses) - if pretrain - loadpretrain!(layers, "Inceptionv3") - end - return Inceptionv3(layers) -end - -@functor Inceptionv3 - -(m::Inceptionv3)(x) = m.layers(x) - -backbone(m::Inceptionv3) = m.layers[1] -classifier(m::Inceptionv3) = m.layers[2] - -## Inceptionv4 - -function mixed_3a() - return Parallel(cat_channels, - MaxPool((3, 3); stride = 2), - Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) -end - -function mixed_4a() - return Parallel(cat_channels, - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((3, 3), 64, 96)...), - Chain(conv_norm((1, 1), 160, 64)..., - conv_norm((1, 7), 64, 64; pad = (0, 3))..., - conv_norm((7, 1), 64, 64; pad = (3, 0))..., - conv_norm((3, 3), 64, 96)...)) -end - -function mixed_5a() - return Parallel(cat_channels, - Chain(conv_norm((3, 3), 192, 192; stride = 2)...), - MaxPool((3, 3); stride = 2)) -end - -function inceptionv4_a() - branch1 = Chain(conv_norm((1, 1), 384, 96)...) - branch2 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 384, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_a() - branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 384, 192)..., - conv_norm((3, 3), 192, 224; pad = 1)..., - conv_norm((3, 3), 224, 256; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_b() - branch1 = Chain(conv_norm((1, 1), 1024, 384)...) - branch2 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((1, 7), 192, 224; pad = (0, 3))..., - conv_norm((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((7, 1), 192, 192; pad = (0, 3))..., - conv_norm((1, 7), 192, 224; pad = (3, 0))..., - conv_norm((7, 1), 224, 224; pad = (0, 3))..., - conv_norm((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_b() - branch1 = Chain(conv_norm((1, 1), 1024, 192)..., - conv_norm((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1024, 256)..., - conv_norm((1, 7), 256, 256; pad = (0, 3))..., - conv_norm((7, 1), 256, 320; pad = (3, 0))..., - conv_norm((3, 3), 320, 320; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_c() - branch1 = Chain(conv_norm((1, 1), 1536, 256)...) - branch2 = Chain(conv_norm((1, 1), 1536, 384)..., - Parallel(cat_channels, - Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_norm((1, 1), 1536, 384)..., - conv_norm((3, 1), 384, 448; pad = (1, 0))..., - conv_norm((1, 3), 448, 512; pad = (0, 1))..., - Parallel(cat_channels, - Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) 
- return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -""" - inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Create an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - mixed_3a(), - mixed_4a(), - mixed_5a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - reduction_a(), # mixed_6a - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - reduction_b(), # mixed_7a - inceptionv4_c(), - inceptionv4_c(), - inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), - Dense(1536, nclasses)) - return Chain(body, head) -end - -""" - Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `Inceptionv4` does not currently support pretrained weights. -""" -struct Inceptionv4 - layers::Any -end - -function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, - nclasses = 1000) - layers = inceptionv4(; inchannels, dropout_rate, nclasses) - if pretrain - loadpretrain!(layers, "Inceptionv4") - end - return Inceptionv4(layers) -end - -@functor Inceptionv4 - -(m::Inceptionv4)(x) = m.layers(x) - -backbone(m::Inceptionv4) = m.layers[1] -classifier(m::Inceptionv4) = m.layers[2] - -## Inception-ResNetv2 - -function mixed_5b() - branch1 = Chain(conv_norm((1, 1), 192, 96)...) - branch2 = Chain(conv_norm((1, 1), 192, 48)..., - conv_norm((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_norm((1, 1), 192, 64)..., - conv_norm((3, 3), 64, 96; pad = 1)..., - conv_norm((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_norm((1, 1), 192, 64)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function block35(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 320, 32)...) - branch2 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_norm((1, 1), 320, 32)..., - conv_norm((3, 3), 32, 48; pad = 1)..., - conv_norm((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_norm((1, 1), 128, 320)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), - branch4, inputscale(scale; activation = relu)), +) -end - -function mixed_6a() - branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 320, 256)..., - conv_norm((3, 3), 256, 256; pad = 1)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function block17(scale = 1.0f0) - branch1 = Chain(conv_norm((1, 1), 1088, 192)...) 
- branch2 = Chain(conv_norm((1, 1), 1088, 128)..., - conv_norm((1, 7), 128, 160; pad = (0, 3))..., - conv_norm((7, 1), 160, 192; pad = (3, 0))...) - branch3 = Chain(conv_norm((1, 1), 384, 1088)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), - branch3, inputscale(scale; activation = relu)), +) -end - -function mixed_7a() - branch1 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_norm((1, 1), 1088, 256)..., - conv_norm((3, 3), 256, 288; pad = 1)..., - conv_norm((3, 3), 288, 320; stride = 2)...) - branch4 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_norm((1, 1), 2080, 192)...) - branch2 = Chain(conv_norm((1, 1), 2080, 192)..., - conv_norm((1, 3), 192, 224; pad = (0, 1))..., - conv_norm((3, 1), 224, 256; pad = (1, 0))...) - branch3 = Chain(conv_norm((1, 1), 448, 2080)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), - branch3, inputscale(scale; activation)), +) -end - -""" - inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an InceptionResNetv2 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., - conv_norm((3, 3), 32, 32)..., - conv_norm((3, 3), 32, 64; pad = 1)..., - MaxPool((3, 3); stride = 2), - conv_norm((3, 3), 64, 80)..., - conv_norm((3, 3), 80, 192)..., - MaxPool((3, 3); stride = 2), - mixed_5b(), - [block35(0.17f0) for _ in 1:10]..., - mixed_6a(), - [block17(0.10f0) for _ in 1:20]..., - mixed_7a(), - [block8(0.20f0) for _ in 1:9]..., - block8(; activation = relu), - conv_norm((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), - Dense(1536, nclasses)) - return Chain(body, head) -end - -""" - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an InceptionResNetv2 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `InceptionResNetv2` does not currently support pretrained weights. -""" -struct InceptionResNetv2 - layers::Any -end - -function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, - nclasses = 1000) - layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) - if pretrain - loadpretrain!(layers, "InceptionResNetv2") - end - return InceptionResNetv2(layers) -end - -@functor InceptionResNetv2 - -(m::InceptionResNetv2)(x) = m.layers(x) - -backbone(m::InceptionResNetv2) = m.layers[1] -classifier(m::InceptionResNetv2) = m.layers[2] - -## Xception - -""" - xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, - grow_at_start = true) - -Create an Xception block. 
-([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `inchannels`: The number of channels in the input. - - `outchannels`: number of output channels. - - `nrepeats`: number of repeats of depthwise separable convolution layers. - - `stride`: stride by which to downsample the input. - - `start_with_relu`: if true, start the block with a ReLU activation. - - `grow_at_start`: if true, increase the number of channels at the first convolution. -""" -function xception_block(inchannels, outchannels, nrepeats; stride = 1, - start_with_relu = true, - grow_at_start = true) - if outchannels != inchannels || stride != 1 - skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) - else - skip = [identity] - end - layers = [] - for i in 1:nrepeats - if grow_at_start - inc = i == 1 ? inchannels : outchannels - outc = outchannels - else - inc = inchannels - outc = i == nrepeats ? outchannels : inchannels - end - push!(layers, relu) - append!(layers, - depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, - use_bn = (false, false))) - push!(layers, BatchNorm(outc)) - end - layers = start_with_relu ? layers : layers[2:end] - push!(layers, MaxPool((3, 3); stride = stride, pad = 1)) - return Parallel(+, Chain(skip...), Chain(layers...)) -end - -""" - xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an Xception model. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_norm((3, 3), 32, 64; bias = false)..., - xception_block(64, 128, 2; stride = 2, start_with_relu = false), - xception_block(128, 256, 2; stride = 2), - xception_block(256, 728, 2; stride = 2), - [xception_block(728, 728, 3) for _ in 1:8]..., - xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), - Dense(2048, nclasses)) - return Chain(body, head) -end - -struct Xception - layers::Any -end - -""" - Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - -Creates an Xception model. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `Xception` does not currently support pretrained weights. 
-""" -function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) - layers = xception(; inchannels, dropout_rate, nclasses) - if pretrain - loadpretrain!(layers, "xception") - end - return Xception(layers) -end - -@functor Xception - -(m::Xception)(x) = m.layers(x) - -backbone(m::Xception) = m.layers[1] -classifier(m::Xception) = m.layers[2] diff --git a/src/convnets/googlenet.jl b/src/convnets/inception/googlenet.jl similarity index 99% rename from src/convnets/googlenet.jl rename to src/convnets/inception/googlenet.jl index 946d0d7f7..8a88ca943 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/inception/googlenet.jl @@ -16,15 +16,12 @@ Create an inception module for use in GoogLeNet """ function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj) branch1 = Chain(Conv((1, 1), inplanes => out_1x1)) - branch2 = Chain(Conv((1, 1), inplanes => red_3x3), Conv((3, 3), red_3x3 => out_3x3; pad = 1)) - branch3 = Chain(Conv((1, 1), inplanes => red_5x5), Conv((5, 5), red_5x5 => out_5x5; pad = 2)) branch4 = Chain(MaxPool((3, 3); stride = 1, pad = 1), Conv((1, 1), inplanes => pool_proj)) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) end @@ -83,6 +80,7 @@ See also [`googlenet`](#). struct GoogLeNet layers::Any end +@functor GoogLeNet function GoogLeNet(; pretrain = false, nclasses = 1000) layers = googlenet(; nclasses = nclasses) @@ -92,8 +90,6 @@ function GoogLeNet(; pretrain = false, nclasses = 1000) return GoogLeNet(layers) end -@functor GoogLeNet - (m::GoogLeNet)(x) = m.layers(x) backbone(m::GoogLeNet) = m.layers[1] diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl new file mode 100644 index 000000000..4b4b78706 --- /dev/null +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -0,0 +1,133 @@ +function mixed_5b() + branch1 = Chain(conv_norm((1, 1), 192, 96)...) + branch2 = Chain(conv_norm((1, 1), 192, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(conv_norm((1, 1), 192, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), 192, 64)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function block35(scale = 1.0f0) + branch1 = Chain(conv_norm((1, 1), 320, 32)...) + branch2 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 48; pad = 1)..., + conv_norm((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(conv_norm((1, 1), 128, 320)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), + branch4, inputscale(scale; activation = relu)), +) +end + +function mixed_6a() + branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 320, 256)..., + conv_norm((3, 3), 256, 256; pad = 1)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function block17(scale = 1.0f0) + branch1 = Chain(conv_norm((1, 1), 1088, 192)...) + branch2 = Chain(conv_norm((1, 1), 1088, 128)..., + conv_norm((1, 7), 128, 160; pad = (0, 3))..., + conv_norm((7, 1), 160, 192; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 384, 1088)...) 
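+    # `branch1` and `branch2` are concatenated (192 + 192 = 384 channels), projected
+    # back to 1088 channels by `branch3`, scaled by `scale`, and added to the block
+    # input through the `SkipConnection` below.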
+ return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), + branch3, inputscale(scale; activation = relu)), +) +end + +function mixed_7a() + branch1 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; pad = 1)..., + conv_norm((3, 3), 288, 320; stride = 2)...) + branch4 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function block8(scale = 1.0f0; activation = identity) + branch1 = Chain(conv_norm((1, 1), 2080, 192)...) + branch2 = Chain(conv_norm((1, 1), 2080, 192)..., + conv_norm((1, 3), 192, 224; pad = (0, 1))..., + conv_norm((3, 1), 224, 256; pad = (1, 0))...) + branch3 = Chain(conv_norm((1, 1), 448, 2080)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), + branch3, inputscale(scale; activation)), +) +end + +""" + inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an InceptionResNetv2 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + MaxPool((3, 3); stride = 2), + conv_norm((3, 3), 64, 80)..., + conv_norm((3, 3), 80, 192)..., + MaxPool((3, 3); stride = 2), + mixed_5b(), + [block35(0.17f0) for _ in 1:10]..., + mixed_6a(), + [block17(0.10f0) for _ in 1:20]..., + mixed_7a(), + [block8(0.20f0) for _ in 1:9]..., + block8(; activation = relu), + conv_norm((1, 1), 2080, 1536)...) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(1536, nclasses)) + return Chain(body, head) +end + +""" + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an InceptionResNetv2 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `InceptionResNetv2` does not currently support pretrained weights. +""" +struct InceptionResNetv2 + layers::Any +end +@functor InceptionResNetv2 + +function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "InceptionResNetv2") + end + return InceptionResNetv2(layers) +end + +(m::InceptionResNetv2)(x) = m.layers(x) + +backbone(m::InceptionResNetv2) = m.layers[1] +classifier(m::InceptionResNetv2) = m.layers[2] diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl new file mode 100644 index 000000000..68b283838 --- /dev/null +++ b/src/convnets/inception/inceptionv3.jl @@ -0,0 +1,196 @@ +## Inceptionv3 + +""" + inceptionv3_a(inplanes, pool_proj) + +Create an Inception-v3 style-A module +(ref: Fig. 5 in [paper](https://arxiv.org/abs/1512.00567v3)). 
+ +# Arguments + + - `inplanes`: number of input feature maps + - `pool_proj`: the number of output feature maps for the pooling projection +""" +function inceptionv3_a(inplanes, pool_proj) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) + branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, pool_proj)...) + return Parallel(cat_channels, + branch1x1, branch5x5, branch3x3, branch_pool) +end + +""" + inceptionv3_b(inplanes) + +Create an Inception-v3 style-B module +(ref: Fig. 10 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_b(inplanes) + branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; stride = 2)...) + branch_pool = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, + branch3x3_1, branch3x3_2, branch_pool) +end + +""" + inceptionv3_c(inplanes, inner_planes, n = 7) + +Create an Inception-v3 style-C module +(ref: Fig. 6 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps + - `inner_planes`: the number of output feature maps within each branch + - `n`: the "grid size" (kernel size) for the convolution layers +""" +function inceptionv3_c(inplanes, inner_planes, n = 7) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) + branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) + branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, 192)...) + return Parallel(cat_channels, + branch1x1, branch7x7_1, branch7x7_2, branch_pool) +end + +""" + inceptionv3_d(inplanes) + +Create an Inception-v3 style-D module +(ref: [pytorch](https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/inception.py#L322)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_d(inplanes) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((1, 7), 192, 192; pad = (0, 3))..., + conv_norm((7, 1), 192, 192; pad = (3, 0))..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch_pool = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, + branch3x3, branch7x7x3, branch_pool) +end + +""" + inceptionv3_e(inplanes) + +Create an Inception-v3 style-E module +(ref: Fig. 7 in [paper](https://arxiv.org/abs/1512.00567v3)). 
+ +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_e(inplanes) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) + branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) + branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., + conv_norm((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, 192)...) + return Parallel(cat_channels, + branch1x1, + Chain(branch3x3_1, + Parallel(cat_channels, + branch3x3_1a, branch3x3_1b)), + Chain(branch3x3_2, + Parallel(cat_channels, + branch3x3_2a, branch3x3_2b)), + branch_pool) +end + +""" + inceptionv3(; nclasses = 1000) + +Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `nclasses`: the number of output classes +""" +function inceptionv3(; nclasses = 1000) + layer = Chain(Chain(conv_norm((3, 3), 3, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + MaxPool((3, 3); stride = 2), + conv_norm((1, 1), 64, 80)..., + conv_norm((3, 3), 80, 192)..., + MaxPool((3, 3); stride = 2), + inceptionv3_a(192, 32), + inceptionv3_a(256, 64), + inceptionv3_a(288, 64), + inceptionv3_b(288), + inceptionv3_c(768, 128), + inceptionv3_c(768, 160), + inceptionv3_c(768, 160), + inceptionv3_c(768, 192), + inceptionv3_d(768), + inceptionv3_e(1280), + inceptionv3_e(2048)), + Chain(AdaptiveMeanPool((1, 1)), + Dropout(0.2), + MLUtils.flatten, + Dense(2048, nclasses))) + return layer +end + +""" + Inceptionv3(; pretrain = false, nclasses = 1000) + +Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). +See also [`inceptionv3`](#). + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +!!! warning + + `Inceptionv3` does not currently support pretrained weights. +""" +struct Inceptionv3 + layers::Any +end + +function Inceptionv3(; pretrain = false, nclasses = 1000) + layers = inceptionv3(; nclasses = nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv3") + end + return Inceptionv3(layers) +end + +@functor Inceptionv3 + +(m::Inceptionv3)(x) = m.layers(x) + +backbone(m::Inceptionv3) = m.layers[1] +classifier(m::Inceptionv3) = m.layers[2] diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl new file mode 100644 index 000000000..bb03646ec --- /dev/null +++ b/src/convnets/inception/inceptionv4.jl @@ -0,0 +1,158 @@ +function mixed_3a() + return Parallel(cat_channels, + MaxPool((3, 3); stride = 2), + Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) +end + +function mixed_4a() + return Parallel(cat_channels, + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((3, 3), 64, 96)...), + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((1, 7), 64, 64; pad = (0, 3))..., + conv_norm((7, 1), 64, 64; pad = (3, 0))..., + conv_norm((3, 3), 64, 96)...)) +end + +function mixed_5a() + return Parallel(cat_channels, + Chain(conv_norm((3, 3), 192, 192; stride = 2)...), + MaxPool((3, 3); stride = 2)) +end + +function inceptionv4_a() + branch1 = Chain(conv_norm((1, 1), 384, 96)...) + branch2 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)...) 
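+    # `branch3` stacks two 3x3 convolutions on a 1x1 reduction (a 5x5-equivalent
+    # receptive field); the four branches are concatenated along the channel dimension.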
+ branch3 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_a() + branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 384, 192)..., + conv_norm((3, 3), 192, 224; pad = 1)..., + conv_norm((3, 3), 224, 256; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_b() + branch1 = Chain(conv_norm((1, 1), 1024, 384)...) + branch2 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((1, 7), 192, 224; pad = (0, 3))..., + conv_norm((7, 1), 224, 256; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((7, 1), 192, 192; pad = (0, 3))..., + conv_norm((1, 7), 192, 224; pad = (3, 0))..., + conv_norm((7, 1), 224, 224; pad = (0, 3))..., + conv_norm((1, 7), 224, 256; pad = (3, 0))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_b() + branch1 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1024, 256)..., + conv_norm((1, 7), 256, 256; pad = (0, 3))..., + conv_norm((7, 1), 256, 320; pad = (3, 0))..., + conv_norm((3, 3), 320, 320; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_c() + branch1 = Chain(conv_norm((1, 1), 1536, 256)...) + branch2 = Chain(conv_norm((1, 1), 1536, 384)..., + Parallel(cat_channels, + Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) + branch3 = Chain(conv_norm((1, 1), 1536, 384)..., + conv_norm((3, 1), 384, 448; pad = (1, 0))..., + conv_norm((1, 3), 448, 512; pad = (0, 1))..., + Parallel(cat_channels, + Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +""" + inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Create an Inceptionv4 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + mixed_3a(), + mixed_4a(), + mixed_5a(), + inceptionv4_a(), + inceptionv4_a(), + inceptionv4_a(), + inceptionv4_a(), + reduction_a(), # mixed_6a + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + reduction_b(), # mixed_7a + inceptionv4_c(), + inceptionv4_c(), + inceptionv4_c()) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(1536, nclasses)) + return Chain(body, head) +end + +""" + Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Inceptionv4 model. 
+([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `Inceptionv4` does not currently support pretrained weights. +""" +struct Inceptionv4 + layers::Any +end +@functor Inceptionv4 + +function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionv4(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv4") + end + return Inceptionv4(layers) +end + +(m::Inceptionv4)(x) = m.layers(x) + +backbone(m::Inceptionv4) = m.layers[1] +classifier(m::Inceptionv4) = m.layers[2] diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl new file mode 100644 index 000000000..6f4928385 --- /dev/null +++ b/src/convnets/inception/xception.jl @@ -0,0 +1,106 @@ +""" + xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, + grow_at_start = true) + +Create an Xception block. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `inchannels`: The number of channels in the input. + - `outchannels`: number of output channels. + - `nrepeats`: number of repeats of depthwise separable convolution layers. + - `stride`: stride by which to downsample the input. + - `start_with_relu`: if true, start the block with a ReLU activation. + - `grow_at_start`: if true, increase the number of channels at the first convolution. +""" +function xception_block(inchannels, outchannels, nrepeats; stride = 1, + start_with_relu = true, + grow_at_start = true) + if outchannels != inchannels || stride != 1 + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, + bias = false) + else + skip = [identity] + end + layers = [] + for i in 1:nrepeats + if grow_at_start + inc = i == 1 ? inchannels : outchannels + outc = outchannels + else + inc = inchannels + outc = i == nrepeats ? outchannels : inchannels + end + push!(layers, relu) + append!(layers, + depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, + use_bn = (false, false))) + push!(layers, BatchNorm(outc)) + end + layers = start_with_relu ? layers : layers[2:end] + push!(layers, MaxPool((3, 3); stride = stride, pad = 1)) + return Parallel(+, Chain(skip...), Chain(layers...)) +end + +""" + xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Xception model. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., + conv_norm((3, 3), 32, 64; bias = false)..., + xception_block(64, 128, 2; stride = 2, start_with_relu = false), + xception_block(128, 256, 2; stride = 2), + xception_block(256, 728, 2; stride = 2), + [xception_block(728, 728, 3) for _ in 1:8]..., + xception_block(728, 1024, 2; stride = 2, grow_at_start = false), + depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., + depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) 
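+    # Classification head: global mean pooling, flatten, dropout and a Dense layer
+    # on top of the 2048-channel backbone output.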
+ head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(2048, nclasses)) + return Chain(body, head) +end + +struct Xception + layers::Any +end +@functor Xception + +""" + Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Xception model. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `Xception` does not currently support pretrained weights. +""" +function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "xception") + end + return Xception(layers) +end + +(m::Xception)(x) = m.layers(x) + +backbone(m::Xception) = m.layers[1] +classifier(m::Xception) = m.layers[2] diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl deleted file mode 100644 index 06274c94e..000000000 --- a/src/convnets/mobilenet.jl +++ /dev/null @@ -1,339 +0,0 @@ -# MobileNetv1 - -""" - mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - fcsize = 1024, - nclasses = 1000) - -Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `dw`: Set true to use a depthwise separable convolution or false for regular convolution - + `o`: The number of output feature maps - + `s`: The stride of the convolutional kernel - + `r`: The number of time this configuration block is repeated - - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. - - `fcsize`: The intermediate fully-connected size between the convolution and final layers - - `nclasses`: The number of output classes -""" -function mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - fcsize = 1024, - nclasses = 1000) - layers = [] - for (dw, outch, stride, nrepeats) in config - outch = Int(outch * width_mult) - for _ in 1:nrepeats - layer = dw ? - depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outch, activation; stride = stride, - pad = 1, - bias = false) - append!(layers, layer) - inchannels = outch - end - end - - return Chain(Chain(layers), - Chain(GlobalMeanPool(), - MLUtils.flatten, - Dense(inchannels, fcsize, activation), - Dense(fcsize, nclasses))) -end - -const mobilenetv1_configs = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), -] - -""" - MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv1 model with the baseline configuration -([reference](https://arxiv.org/abs/1704.04861v1)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. 
- -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv1`](#). -""" -struct MobileNetv1 - layers::Any -end - -function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - if pretrain - loadpretrain!(layers, string("MobileNetv1")) - end - return MobileNetv1(layers) -end - -@functor MobileNetv1 - -(m::MobileNetv1)(x) = m.layers(x) - -backbone(m::MobileNetv1) = m.layers[1] -classifier(m::MobileNetv1) = m.layers[2] - -# MobileNetv2 - -""" - mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - -Create a MobileNetv2 model. -([reference](https://arxiv.org/abs/1801.04381)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer - + `c`: The number of output feature maps - + `n`: The number of times a block is repeated - + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: The number of output classes -""" -function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - # building first layer - inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) - layers = [] - append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) - # building inverted residual blocks - for (t, c, n, s, a) in configs - outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) - for i in 1:n - push!(layers, - invertedresidual(3, inplanes, inplanes * t, outplanes, a; - stride = i == 1 ? s : 1)) - inplanes = outplanes - end - end - # building last several layers - outplanes = (width_mult > 1) ? - _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : - max_width - return Chain(Chain(Chain(layers), - conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(outplanes, nclasses))) -end - -# Layer configurations for MobileNetv2 -const mobilenetv2_configs = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), -] - -# Model definition for MobileNetv2 -struct MobileNetv2 - layers::Any -end - -""" - MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv2 model with the specified configuration. -([reference](https://arxiv.org/abs/1801.04381)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. 
- - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv2`](#). -""" -function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv2")) - if pretrain - loadpretrain!(layers, string("MobileNetv2")) - end - return MobileNetv2(layers) -end - -@functor MobileNetv2 - -(m::MobileNetv2)(x) = m.layers(x) - -backbone(m::MobileNetv2) = m.layers[1] -classifier(m::MobileNetv2) = m.layers[2] - -# MobileNetv3 - -""" - mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - -Create a MobileNetv3 model. -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - - `configs`: a "list of tuples" configuration for each layer that details: - - + `k::Integer` - The size of the convolutional kernel - + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - + `t::Integer` - The number of output feature maps for a given block - + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - + `s::Integer` - The stride of the convolutional kernel - + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: the number of output classes -""" -function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - # building first layer - inplanes = _round_channels(16 * width_mult, 8) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, - bias = false)) - explanes = 0 - # building inverted residual blocks - for (k, t, c, r, a, s) in configs - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - invertedresidual(k, inplanes, explanes, outplanes, a; - stride = s, reduction = r)) - inplanes = outplanes - end - # building last several layers - output_channel = max_width - output_channel = width_mult > 1.0 ? 
_round_channels(output_channel * width_mult, 8) : - output_channel - classifier = Chain(Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses)) - return Chain(Chain(Chain(layers), - conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) -end - -# Configurations for small and large mode for MobileNetv3 -mobilenetv3_configs = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) - -# Model definition for MobileNetv3 -struct MobileNetv3 - layers::Any -end - -""" - MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv3 model with the specified configuration. -([reference](https://arxiv.org/abs/1905.02244)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -# Arguments - - - `mode`: :small or :large for the size of the model (see paper). - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of channels in the input. - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -See also [`Metalhead.mobilenetv3`](#). -""" -function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, - pretrain = false, nclasses = 1000) - @assert mode in [:large, :small] "`mode` has to be either :large or :small" - max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, - nclasses) - if pretrain - loadpretrain!(layers, string("MobileNetv3", mode)) - end - return MobileNetv3(layers) -end - -@functor MobileNetv3 - -(m::MobileNetv3)(x) = m.layers(x) - -backbone(m::MobileNetv3) = m.layers[1] -classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl new file mode 100644 index 000000000..4add739b4 --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -0,0 +1,102 @@ +""" + mobilenetv1(width_mult, config; + activation = relu, + inchannels = 3, + fcsize = 1024, + nclasses = 1000) + +Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
+
+# Arguments
+
+  - `width_mult`: Controls the number of output feature maps in each block
+    (with 1.0 being the default in the paper)
+
+  - `config`: A "list of tuples" configuration for each layer that details:
+
+      + `dw`: Set true to use a depthwise separable convolution or false for regular convolution
+      + `o`: The number of output feature maps
+      + `s`: The stride of the convolutional kernel
+      + `r`: The number of times this configuration block is repeated
+  - `activation`: The activation function to use throughout the network
+  - `inchannels`: The number of input channels.
+  - `fcsize`: The intermediate fully-connected size between the convolution and final layers
+  - `nclasses`: The number of output classes
+"""
+function mobilenetv1(width_mult, config;
+                     activation = relu,
+                     inchannels = 3,
+                     fcsize = 1024,
+                     nclasses = 1000)
+    layers = []
+    for (dw, outch, stride, nrepeats) in config
+        outch = Int(outch * width_mult)
+        for _ in 1:nrepeats
+            layer = dw ?
+                    depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
+                                          stride = stride, pad = 1, bias = false) :
+                    conv_norm((3, 3), inchannels, outch, activation; stride = stride,
+                              pad = 1,
+                              bias = false)
+            append!(layers, layer)
+            inchannels = outch
+        end
+    end
+
+    return Chain(Chain(layers),
+                 Chain(GlobalMeanPool(),
+                       MLUtils.flatten,
+                       Dense(inchannels, fcsize, activation),
+                       Dense(fcsize, nclasses)))
+end
+
+const mobilenetv1_configs = [
+    # dw, c, s, r
+    (false, 32, 2, 1),
+    (true, 64, 1, 1),
+    (true, 128, 2, 1),
+    (true, 128, 1, 1),
+    (true, 256, 2, 1),
+    (true, 256, 1, 1),
+    (true, 512, 2, 1),
+    (true, 512, 1, 5),
+    (true, 1024, 2, 1),
+    (true, 1024, 1, 1),
+]
+
+"""
+    MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000)
+
+Create a MobileNetv1 model with the baseline configuration
+([reference](https://arxiv.org/abs/1704.04861v1)).
+Set `pretrain` to `true` to load the pretrained weights for ImageNet.
+
+# Arguments
+
+  - `width_mult`: Controls the number of output feature maps in each block
+    (with 1.0 being the default in the paper;
+    this is usually a value between 0.1 and 1.4)
+  - `inchannels`: The number of input channels.
+  - `pretrain`: Whether to load the pre-trained weights for ImageNet
+  - `nclasses`: The number of output classes
+
+See also [`Metalhead.mobilenetv1`](#).
+"""
+struct MobileNetv1
+    layers::Any
+end
+@functor MobileNetv1
+
+function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false,
+                     nclasses = 1000)
+    layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers, string("MobileNetv1"))
+    end
+    return MobileNetv1(layers)
+end
+
+(m::MobileNetv1)(x) = m.layers(x)
+
+backbone(m::MobileNetv1) = m.layers[1]
+classifier(m::MobileNetv1) = m.layers[2]
diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl
new file mode 100644
index 000000000..21c017b42
--- /dev/null
+++ b/src/convnets/mobilenet/mobilenetv2.jl
@@ -0,0 +1,97 @@
+"""
+    mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000)
+
+Create a MobileNetv2 model.
+([reference](https://arxiv.org/abs/1801.04381)).
+
+# Arguments
+
+  - `width_mult`: Controls the number of output feature maps in each block
+    (with 1.0 being the default in the paper)
+
+  - `configs`: A "list of tuples" configuration for each layer that details:
+
+      + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer
+      + `c`: The number of output feature maps
+      + `n`: The number of times a block is repeated
+      + `s`: The stride of the convolutional kernel
+      + `a`: The activation function used in the bottleneck layer
+  - `inchannels`: The number of input channels.
+  - `max_width`: The maximum number of feature maps in any layer of the network
+  - `nclasses`: The number of output classes
+"""
+function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000)
+    # building first layer
+    inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8)
+    layers = []
+    append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2))
+    # building inverted residual blocks
+    for (t, c, n, s, a) in configs
+        outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8)
+        for i in 1:n
+            push!(layers,
+                  invertedresidual(3, inplanes, inplanes * t, outplanes, a;
+                                   stride = i == 1 ? s : 1))
+            inplanes = outplanes
+        end
+    end
+    # building last several layers
+    outplanes = (width_mult > 1) ?
+                _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) :
+                max_width
+    return Chain(Chain(Chain(layers),
+                       conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)...),
+                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten,
+                       Dense(outplanes, nclasses)))
+end
+
+# Layer configurations for MobileNetv2
+const mobilenetv2_configs = [
+    # t, c, n, s, a
+    (1, 16, 1, 1, relu6),
+    (6, 24, 2, 2, relu6),
+    (6, 32, 3, 2, relu6),
+    (6, 64, 4, 2, relu6),
+    (6, 96, 3, 1, relu6),
+    (6, 160, 3, 2, relu6),
+    (6, 320, 1, 1, relu6),
+]
+
+# Model definition for MobileNetv2
+struct MobileNetv2
+    layers::Any
+end
+@functor MobileNetv2
+
+"""
+    MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000)
+
+Create a MobileNetv2 model with the specified configuration.
+([reference](https://arxiv.org/abs/1801.04381)).
+Set `pretrain` to `true` to load the pretrained weights for ImageNet.
+
+# Arguments
+
+  - `width_mult`: Controls the number of output feature maps in each block
+    (with 1.0 being the default in the paper;
+    this is usually a value between 0.1 and 1.4)
+  - `inchannels`: The number of input channels.
+  - `pretrain`: Whether to load the pre-trained weights for ImageNet
+  - `nclasses`: The number of output classes
+
+See also [`Metalhead.mobilenetv2`](#).
+"""
+function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false,
+                     nclasses = 1000)
+    layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers, string("MobileNetv2"))
+    end
+    return MobileNetv2(layers)
+end
+
+(m::MobileNetv2)(x) = m.layers(x)
+
+backbone(m::MobileNetv2) = m.layers[1]
+classifier(m::MobileNetv2) = m.layers[2]
diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl
new file mode 100644
index 000000000..6bc444407
--- /dev/null
+++ b/src/convnets/mobilenet/mobilenetv3.jl
@@ -0,0 +1,129 @@
+"""
+    mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000)
+
+Create a MobileNetv3 model.
+([reference](https://arxiv.org/abs/1905.02244)).
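Before the MobileNetv3 arguments below, a short sketch of the MobileNetv2 wrapper defined above and of the backbone/classifier split used throughout these files. Input size and the use of qualified (unexported) accessors are illustrative assumptions.

```julia
using Metalhead, Flux

# Default ImageNet-style configuration from the file above.
model = MobileNetv2()
x = rand(Float32, 224, 224, 3, 1)
@assert size(model(x)) == (1000, 1)

# The two-element Chain means feature extraction and the head can be used separately.
features = Metalhead.backbone(model)(x)          # convolutional feature maps
logits   = Metalhead.classifier(model)(features) # same as model(x)
```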
+ +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + + - `configs`: a "list of tuples" configuration for each layer that details: + + + `k::Integer` - The size of the convolutional kernel + + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer + + `t::Integer` - The number of output feature maps for a given block + + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + + `s::Integer` - The stride of the convolutional kernel + + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) + - `inchannels`: The number of input channels. + - `max_width`: The maximum number of feature maps in any layer of the network + - `nclasses`: the number of output classes +""" +function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) + # building first layer + inplanes = _round_channels(16 * width_mult, 8) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, + bias = false)) + explanes = 0 + # building inverted residual blocks + for (k, t, c, r, a, s) in configs + # inverted residual layers + outplanes = _round_channels(c * width_mult, 8) + explanes = _round_channels(inplanes * t, 8) + push!(layers, + invertedresidual(k, inplanes, explanes, outplanes, a; + stride = s, reduction = r)) + inplanes = outplanes + end + # building last several layers + output_channel = max_width + output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : + output_channel + classifier = Chain(Dense(explanes, output_channel, hardswish), + Dropout(0.2), + Dense(output_channel, nclasses)) + return Chain(Chain(Chain(layers), + conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)...), + Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) +end + +# Configurations for small and large mode for MobileNetv3 +mobilenetv3_configs = Dict(:small => [ + # k, t, c, SE, a, s + (3, 1, 16, 4, relu, 2), + (3, 4.5, 24, nothing, relu, 2), + (3, 3.67, 24, nothing, relu, 1), + (5, 4, 40, 4, hardswish, 2), + (5, 6, 40, 4, hardswish, 1), + (5, 6, 40, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 2), + (5, 6, 96, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 1), + ], + :large => [ + # k, t, c, SE, a, s + (3, 1, 16, nothing, relu, 1), + (3, 4, 24, nothing, relu, 2), + (3, 3, 24, nothing, relu, 1), + (5, 3, 40, 4, relu, 2), + (5, 3, 40, 4, relu, 1), + (5, 3, 40, 4, relu, 1), + (3, 6, 80, nothing, hardswish, 2), + (3, 2.5, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 2), + (5, 6, 160, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 1), + ]) + +# Model definition for MobileNetv3 +struct MobileNetv3 + layers::Any +end +@functor MobileNetv3 + +""" + MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv3 model with the specified configuration. +([reference](https://arxiv.org/abs/1905.02244)). +Set `pretrain = true` to load the model with pre-trained weights for ImageNet. + +# Arguments + + - `mode`: :small or :large for the size of the model (see paper). 
+ - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of channels in the input. + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +See also [`Metalhead.mobilenetv3`](#). +""" +function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, + pretrain = false, nclasses = 1000) + @assert mode in [:large, :small] "`mode` has to be either :large or :small" + max_width = (mode == :large) ? 1280 : 1024 + layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, + nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv3", mode)) + end + return MobileNetv3(layers) +end + +(m::MobileNetv3)(x) = m.layers(x) + +backbone(m::MobileNetv3) = m.layers[1] +classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index ad60814d4..9af497844 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -32,9 +32,8 @@ function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activati norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = expansion_factor(basicblock) first_planes = planes ÷ reduction_factor - outplanes = planes * expansion + outplanes = planes * expansion_factor(basicblock) conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm, stride, pad = 1, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm, @@ -81,10 +80,9 @@ function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor - outplanes = planes * expansion + outplanes = planes * expansion_factor(bottleneck) conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm, @@ -322,8 +320,8 @@ function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; end # block-layer configurations for ResNet-like models -const resnet_config = Dict(18 => (basicblock, [2, 2, 2, 2]), - 34 => (basicblock, [3, 4, 6, 3]), - 50 => (bottleneck, [3, 4, 6, 3]), - 101 => (bottleneck, [3, 4, 23, 3]), - 152 => (bottleneck, [3, 8, 36, 3])) +const resnet_configs = Dict(18 => (basicblock, [2, 2, 2, 2]), + 34 => (basicblock, [3, 4, 6, 3]), + 50 => (bottleneck, [3, 4, 6, 3]), + 101 => (bottleneck, [3, 4, 23, 3]), + 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index ffbee32dc..7bebb0873 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -1,9 +1,3 @@ -const resnet_shortcuts = Dict(18 => [:A, :B, :B, :B], - 34 => [:A, :B, :B, :B], - 50 => [:B, :B, :B, :B], - 101 => [:B, :B, :B, :B], - 152 => [:B, :B, :B, :B]) - """ ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) @@ -28,17 +22,16 @@ struct ResNet end @functor ResNet -(m::ResNet)(x) = m.layers(x) - function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [18, 34, 50, 
101, 152] - "Invalid depth. Must be one of [18, 34, 50, 101, 152]" - layers = resnet(resnet_config[depth]..., resnet_shortcuts[depth]; inchannels, nclasses) + _checkconfig(depth, keys(resnet_configs)) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses) if pretrain loadpretrain!(layers, string("ResNet", depth)) end return ResNet(layers) end +(m::ResNet)(x) = m.layers(x) + backbone(m::ResNet) = m.layers[1] classifier(m::ResNet) = m.layers[2] diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index cee2e4757..47e81d44d 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -29,9 +29,8 @@ end function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width) + _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 26bfbb6c1..ae25a73e8 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -6,7 +6,7 @@ Creates a SEResNet model with the specified depth. # Arguments - - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet - `inchannels`: the number of input channels. - `nclasses`: the number of output classes @@ -25,9 +25,8 @@ end (m::SEResNet)(x) = m.layers(x) function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, + _checkconfig(depth, keys(resnet_configs)) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses, attn_fn = planes -> squeeze_excite(planes)) if pretrain loadpretrain!(layers, string("SEResNet", depth)) @@ -69,9 +68,8 @@ end function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - @assert depth in [50, 101, 152] - "Invalid depth. 
Must be one of [50, 101, 152]" - layers = resnet(resnet_config[depth]...; inchannels, nclasses, cardinality, base_width, + _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) + layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width, attn_fn = planes -> squeeze_excite(planes)) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index df458f9ff..abcdd63f8 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -15,11 +15,7 @@ function fire(inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes) branch_1 = Conv((1, 1), inplanes => squeeze_planes, relu) branch_2 = Conv((1, 1), squeeze_planes => expand1x1_planes, relu) branch_3 = Conv((3, 3), squeeze_planes => expand3x3_planes, relu; pad = 1) - - return Chain(branch_1, - Parallel(cat_channels, - branch_2, - branch_3)) + return Chain(branch_1, Parallel(cat_channels, branch_2, branch_3)) end """ @@ -29,24 +25,22 @@ Create a SqueezeNet ([reference](https://arxiv.org/abs/1602.07360v4)). """ function squeezenet() - layers = Chain(Chain(Conv((3, 3), 3 => 64, relu; stride = 2), - MaxPool((3, 3); stride = 2), - fire(64, 16, 64, 64), - fire(128, 16, 64, 64), - MaxPool((3, 3); stride = 2), - fire(128, 32, 128, 128), - fire(256, 32, 128, 128), - MaxPool((3, 3); stride = 2), - fire(256, 48, 192, 192), - fire(384, 48, 192, 192), - fire(384, 64, 256, 256), - fire(512, 64, 256, 256), - Dropout(0.5), - Conv((1, 1), 512 => 1000, relu)), - AdaptiveMeanPool((1, 1)), - MLUtils.flatten) - - return layers + return Chain(Chain(Conv((3, 3), 3 => 64, relu; stride = 2), + MaxPool((3, 3); stride = 2), + fire(64, 16, 64, 64), + fire(128, 16, 64, 64), + MaxPool((3, 3); stride = 2), + fire(128, 32, 128, 128), + fire(256, 32, 128, 128), + MaxPool((3, 3); stride = 2), + fire(256, 48, 192, 192), + fire(384, 48, 192, 192), + fire(384, 64, 256, 256), + fire(512, 64, 256, 256), + Dropout(0.5), + Conv((1, 1), 512 => 1000, relu)), + AdaptiveMeanPool((1, 1)), + MLUtils.flatten) end """ @@ -65,6 +59,7 @@ See also [`squeezenet`](#). struct SqueezeNet layers::Any end +@functor SqueezeNet function SqueezeNet(; pretrain = false) layers = squeezenet() @@ -74,8 +69,6 @@ function SqueezeNet(; pretrain = false) return SqueezeNet(layers) end -@functor SqueezeNet - (m::SqueezeNet)(x) = m.layers(x) backbone(m::SqueezeNet) = m.layers[1] diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index c8a2e6344..3a1a8ac10 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -99,15 +99,15 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr return Chain(Chain(conv), class) end -const vgg_conv_config = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], - :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], - :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], - :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) +const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], + :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], + :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], + :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) -const vgg_config = Dict(11 => :A, - 13 => :B, - 16 => :D, - 19 => :E) +const vgg_configs = Dict(11 => :A, + 13 => :B, + 16 => :D, + 19 => :E) struct VGG layers::Any @@ -153,8 +153,8 @@ See also [`VGG`](#). 
- `pretrain`: set to `true` to load pre-trained model weights for ImageNet """ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000) - @assert depth in keys(vgg_config) "depth must be from one in $(sort(collect(keys(vgg_config))))" - model = VGG((224, 224); config = vgg_conv_config[vgg_config[depth]], + _checkconfig(depth, keys(vgg_configs)) + model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]], inchannels = 3, batchnorm = batchnorm, nclasses = nclasses, diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 8ad47af4a..b433fbbc9 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -1,6 +1,7 @@ # Generates the mask to be used for `DropBlock` @inline function _dropblock_mask(rng, x, gamma, clipped_block_size) - block_mask = Flux.f32(rand_like(rng, x) .< gamma) + block_mask = rand_like(rng, x) + block_mask .= block_mask .< gamma return 1 .- maxpool(block_mask, (clipped_block_size, clipped_block_size); stride = 1, pad = clipped_block_size ÷ 2) end @@ -35,7 +36,7 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, bl gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / ((W - block_size + 1) * (H - block_size + 1)) block_mask = dropblock_mask(rng, x, gamma, clipped_block_size) - normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6)) + normalize_scale = length(block_mask) / sum(block_mask) .+ T(1e-6) return x .* block_mask .* normalize_scale end @@ -55,21 +56,17 @@ mutable struct DropBlock{F, R <: AbstractRNG} active::Union{Bool, Nothing} rng::R end - @functor DropBlock trainable(a::DropBlock) = (;) -function _dropblock_checks(x::T, drop_block_prob, gamma_scale) where {T} +function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_scale) @assert 0 ≤ drop_block_prob ≤ 1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" @assert 0 ≤ gamma_scale ≤ 1 "gamma_scale must be between 0 and 1, got $gamma_scale" - if !(T <: AbstractArray) - throw(ArgumentError("x must be an `AbstractArray`")) - end - if ndims(x) != 4 - throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`")) - end +end +function _dropblock_checks(x, drop_block_prob, gamma_scale) + throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) end ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_scale) diff --git a/src/mixers/core.jl b/src/mixers/core.jl new file mode 100644 index 000000000..6a55f048e --- /dev/null +++ b/src/mixers/core.jl @@ -0,0 +1,43 @@ +""" + mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., + depth = 12, nclasses = 1000, kwargs...) + +Creates a model with the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)). 
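Before the `mlpmixer` arguments below, a brief aside on the `DropBlock` changes earlier in this diff: the functional form keeps the input shape and only zeroes out (and rescales around) square regions. A minimal sketch, assuming the five-argument `dropblock(rng, x, drop_block_prob, block_size, gamma_scale)` shown in the hunk above; the `Metalhead.Layers` module path and the chosen rates are assumptions for illustration.

```julia
using Metalhead, Random

x = rand(Float32, 16, 16, 8, 2)   # 4-D (H, W, C, N) input, as the checks require
y = Metalhead.Layers.dropblock(Random.default_rng(), x, 0.2, 5, 1.0)
@assert size(y) == size(x)        # shape is preserved
# With overwhelming probability some contiguous 5x5 regions of `y` are now zero,
# and the surviving activations are rescaled by length(mask) / sum(mask).
```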
+ +# Arguments + + - `block`: the type of mixer block to use in the model - architecture dependent + (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) + - `imsize`: the size of the input image + - `inchannels`: the number of input channels + - `norm_layer`: the normalization layer to use in the model + - `patch_size`: the size of the patches + - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) + - `drop_path_rate`: Stochastic depth rate + - `depth`: the number of blocks in the model + - `nclasses`: number of output classes + - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if + not specified. +""" +function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, + norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), + embedplanes = 512, drop_path_rate = 0.0, + depth = 12, nclasses = 1000, kwargs...) + npatches = prod(imsize .÷ patch_size) + dp_rates = linear_scheduler(drop_path_rate; depth) + layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), + Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], + kwargs...) + for i in 1:depth])) + classification_head = Chain(norm_layer(embedplanes), seconddimmean, + Dense(embedplanes, nclasses)) + return Chain(layers, classification_head) +end + +# Configurations for MLPMixer models +mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), + :base => Dict(:depth => 12, :planes => 768), + :large => Dict(:depth => 24, :planes => 1024), + :huge => Dict(:depth => 32, :planes => 1280)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl new file mode 100644 index 000000000..4e681e9b4 --- /dev/null +++ b/src/mixers/gmlp.jl @@ -0,0 +1,110 @@ +""" + SpatialGatingUnit(norm, proj) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `norm`: the normalisation layer to use + - `proj`: the projection layer to use +""" +struct SpatialGatingUnit{T, F} + norm::T + proj::F +end +@functor SpatialGatingUnit + +""" + SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `norm_layer`: the normalisation layer to use +""" +function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) + gateplanes = planes ÷ 2 + norm = norm_layer(gateplanes) + proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) + return SpatialGatingUnit(norm, proj) +end + +function (m::SpatialGatingUnit)(x) + u, v = chunk(x, 2; dims = 1) + v = m.norm(v) + v = m.proj(permutedims(v, (2, 1, 3))) + return u .* permutedims(v, (2, 1, 3)) +end + +""" + spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, + norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, + activation = gelu) + +Creates a feedforward block based on the gMLP model architecture described in the paper. 
+([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block + - `norm_layer`: the normalisation layer to use + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks +""" +function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, + mlp_layer = gated_mlp_block, dropout_rate = 0.0, + drop_path_rate = 0.0, + activation = gelu) + channelplanes = Int(mlp_ratio * planes) + sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) + return SkipConnection(Chain(norm_layer(planes), + mlp_layer(sgu, planes, channelplanes; activation, + dropout_rate), + DropPath(drop_path_rate)), +) +end + +struct gMLP + layers::Any +end +@functor gMLP + +""" + gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) + +Creates a model with the gMLP architecture. +([reference](https://arxiv.org/abs/2105.08050)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(mixer_configs)) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, + patch_size, embedplanes, drop_path_rate, depth, nclasses) + return gMLP(layers) +end + +(m::gMLP)(x) = m.layers(x) + +backbone(m::gMLP) = m.layers[1] +classifier(m::gMLP) = m.layers[2] diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl new file mode 100644 index 000000000..e3da17a23 --- /dev/null +++ b/src/mixers/mlpmixer.jl @@ -0,0 +1,69 @@ +""" + mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout_rate = 0., drop_path_rate = 0., activation = gelu) + +Creates a feedforward block for the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP + and/or the channel mixing MLP as a ratio to the number of planes in the block. 
+ - `mlp_layer`: the MLP layer to use in the block + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks +""" +function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] + return Chain(SkipConnection(Chain(LayerNorm(planes), + swapdims((2, 1, 3)), + mlp_layer(npatches, tokenplanes; activation, + dropout_rate), + swapdims((2, 1, 3)), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(LayerNorm(planes), + mlp_layer(planes, channelplanes; activation, + dropout_rate), + DropPath(drop_path_rate)), +)) +end + +struct MLPMixer + layers::Any +end +@functor MLPMixer + +""" + MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) + +Creates a model with the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(mixer_configs)) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, + nclasses) + return MLPMixer(layers) +end + +(m::MLPMixer)(x) = m.layers(x) + +backbone(m::MLPMixer) = m.layers[1] +classifier(m::MLPMixer) = m.layers[2] diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl new file mode 100644 index 000000000..38163702c --- /dev/null +++ b/src/mixers/resmlp.jl @@ -0,0 +1,72 @@ +""" + resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, + activation = gelu, λ = 1e-4) + +Creates a block for the ResMixer architecture. +([reference](https://arxiv.org/abs/2105.03404)). 
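Before the ResMixer block arguments below, a usage sketch for the `MLPMixer` wrapper and `mixerblock` defined above. The preset, input size, and the qualified call to the unexported block constructor are illustrative assumptions.

```julia
using Metalhead, Flux

# Smallest preset from `mixer_configs` (depth 8, 512 planes); other defaults follow the file above.
model = MLPMixer(:small)
x = rand(Float32, 224, 224, 3, 1)
@assert size(model(x)) == (1000, 1)

# A single mixer block can be instantiated on its own; it expects input shaped
# (embedplanes, npatches, batch), i.e. the output of the patch embedding.
block = Metalhead.mixerblock(512, (224 ÷ 16)^2)
```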
+ +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block + - `mlp_layer`: the MLP block to use + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks + - `λ`: initialisation constant for the LayerScale +""" +function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, + λ = 1e-4) + return Chain(SkipConnection(Chain(Flux.Scale(planes), + swapdims((2, 1, 3)), + Dense(npatches, npatches), + swapdims((2, 1, 3)), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(Flux.Scale(planes), + mlp_layer(planes, Int(mlp_ratio * planes); + dropout_rate, + activation), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +)) +end + +struct ResMLP + layers::Any +end +@functor ResMLP + +""" + ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + +Creates a model with the ResMLP architecture. +([reference](https://arxiv.org/abs/2105.03404)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(mixer_configs)) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, + drop_path_rate, depth, nclasses) + return ResMLP(layers) +end + +(m::ResMLP)(x) = m.layers(x) + +backbone(m::ResMLP) = m.layers[1] +classifier(m::ResMLP) = m.layers[2] diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl deleted file mode 100644 index ac66be28b..000000000 --- a/src/other/mlpmixer.jl +++ /dev/null @@ -1,302 +0,0 @@ -""" - mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout_rate = 0., drop_path_rate = 0., activation = gelu) - -Creates a feedforward block for the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP - and/or the channel mixing MLP as a ratio to the number of planes in the block. 
- - `mlp_layer`: the MLP layer to use in the block - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks -""" -function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) - tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] - return Chain(SkipConnection(Chain(LayerNorm(planes), - swapdims((2, 1, 3)), - mlp_layer(npatches, tokenplanes; activation, - dropout_rate), - swapdims((2, 1, 3)), - DropPath(drop_path_rate)), +), - SkipConnection(Chain(LayerNorm(planes), - mlp_layer(planes, channelplanes; activation, - dropout_rate), - DropPath(drop_path_rate)), +)) -end - -""" - mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., - depth = 12, nclasses = 1000, kwargs...) - -Creates a model with the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)). - -# Arguments - - - `block`: the type of mixer block to use in the model - architecture dependent - (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) - - `imsize`: the size of the input image - - `inchannels`: the number of input channels - - `norm_layer`: the normalization layer to use in the model - - `patch_size`: the size of the patches - - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) - - `drop_path_rate`: Stochastic depth rate - - `depth`: the number of blocks in the model - - `nclasses`: number of output classes - - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if - not specified. -""" -function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, - norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), - embedplanes = 512, drop_path_rate = 0.0, - depth = 12, nclasses = 1000, kwargs...) - npatches = prod(imsize .÷ patch_size) - dp_rates = linear_scheduler(drop_path_rate; depth) - layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), - Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], - kwargs...) - for i in 1:depth])) - - classification_head = Chain(norm_layer(embedplanes), seconddimmean, - Dense(embedplanes, nclasses)) - return Chain(layers, classification_head) -end - -# Configurations for MLPMixer models -mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), - :base => Dict(:depth => 12, :planes => 768), - :large => Dict(:depth => 24, :planes => 1024), - :huge => Dict(:depth => 32, :planes => 1280)) - -struct MLPMixer - layers::Any -end - -""" - MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) - -Creates a model with the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). 
-""" -function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, - nclasses) - return MLPMixer(layers) -end - -@functor MLPMixer - -(m::MLPMixer)(x) = m.layers(x) - -backbone(m::MLPMixer) = m.layers[1] -classifier(m::MLPMixer) = m.layers[2] - -""" - resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, - activation = gelu, λ = 1e-4) - -Creates a block for the ResMixer architecture. -([reference](https://arxiv.org/abs/2105.03404)). - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number - of planes in the block - - `mlp_layer`: the MLP block to use - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks - - `λ`: initialisation constant for the LayerScale -""" -function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, - λ = 1e-4) - return Chain(SkipConnection(Chain(Flux.Scale(planes), - swapdims((2, 1, 3)), - Dense(npatches, npatches), - swapdims((2, 1, 3)), - LayerScale(planes, λ), - DropPath(drop_path_rate)), +), - SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); - dropout_rate, - activation), - LayerScale(planes, λ), - DropPath(drop_path_rate)), +)) -end - -struct ResMLP - layers::Any -end - -""" - ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), - drop_path_rate = 0., nclasses = 1000) - -Creates a model with the ResMLP architecture. -([reference](https://arxiv.org/abs/2105.03404)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, - drop_path_rate, depth, nclasses) - return ResMLP(layers) -end - -@functor ResMLP - -(m::ResMLP)(x) = m.layers(x) - -backbone(m::ResMLP) = m.layers[1] -classifier(m::ResMLP) = m.layers[2] - -""" - SpatialGatingUnit(norm, proj) - -Creates a spatial gating unit as described in the gMLP paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `norm`: the normalisation layer to use - - `proj`: the projection layer to use -""" -struct SpatialGatingUnit{T, F} - norm::T - proj::F -end - -""" - SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - -Creates a spatial gating unit as described in the gMLP paper. 
-([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `norm_layer`: the normalisation layer to use -""" -function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - gateplanes = planes ÷ 2 - norm = norm_layer(gateplanes) - proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) - return SpatialGatingUnit(norm, proj) -end - -@functor SpatialGatingUnit - -function (m::SpatialGatingUnit)(x) - u, v = chunk(x, 2; dims = 1) - v = m.norm(v) - v = m.proj(permutedims(v, (2, 1, 3))) - return u .* permutedims(v, (2, 1, 3)) -end - -""" - spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, - activation = gelu) - -Creates a feedforward block based on the gMLP model architecture described in the paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number - of planes in the block - - `norm_layer`: the normalisation layer to use - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks -""" -function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, dropout_rate = 0.0, - drop_path_rate = 0.0, - activation = gelu) - channelplanes = Int(mlp_ratio * planes) - sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) - return SkipConnection(Chain(norm_layer(planes), - mlp_layer(sgu, planes, channelplanes; activation, - dropout_rate), - DropPath(drop_path_rate)), +) -end - -struct gMLP - layers::Any -end - -""" - gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) - -Creates a model with the gMLP architecture. -([reference](https://arxiv.org/abs/2105.08050)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, - patch_size, embedplanes, drop_path_rate, depth, nclasses) - return gMLP(layers) -end - -@functor gMLP - -(m::gMLP)(x) = m.layers(x) - -backbone(m::gMLP) = m.layers[1] -classifier(m::gMLP) = m.layers[2] diff --git a/src/utilities.jl b/src/utilities.jl index b420efd0b..981777228 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -18,7 +18,7 @@ function in [`resnet`](#). See also [`reluadd`](#). """ -addact(activation = relu, xs...) = activation(sum(tuple(xs...))) +addact(activation = relu, xs...) = activation(sum(xs)) """ actadd(activation = relu, xs...) @@ -29,7 +29,7 @@ function to them. 
Useful as the `connection` argument for the block function in See also [`addrelu`](#). """ -actadd(activation = relu, xs...) = sum(activation.(tuple(xs...))) +actadd(activation = relu, xs...) = sum(activation.(x) for x in xs) """ cat_channels(x, y, zs...) @@ -68,5 +68,11 @@ end Returns the dropout rates for a given depth using the linear scaling rule. """ function linear_scheduler(drop_rate = 0.0; depth, start_value = 0.0) - return LinRange{Float32}(start_value, drop_rate, depth) + return LinRange(start_value, drop_rate, depth) +end + +# Utility function for depth and configuration checks in models +function _checkconfig(config, configs) + @assert config in configs + return "Invalid configuration. Must be one of $(sort(collect(configs)))." end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 856b64697..93eba09ee 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -94,10 +94,11 @@ See also [`Metalhead.vit`](#). struct ViT layers::Any end +@functor ViT function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) - @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))" + _checkconfig(mode, keys(vit_configs)) kwargs = vit_configs[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) return ViT(layers) @@ -107,5 +108,3 @@ end backbone(m::ViT) = m.layers[1] classifier(m::ViT) = m.layers[2] - -@functor ViT diff --git a/test/convnets.jl b/test/convnets.jl index 601f51421..860ac422b 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -46,7 +46,7 @@ end (dropout_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), ] @testset for drop_rates in drop_list - m = Metalhead.resnet(block_fn, layers; drop_rates) + m = Metalhead.resnet(block_fn, layers; drop_rates...) 
@test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() @@ -75,7 +75,7 @@ end end @testset "SEResNet" begin - @testset for depth in [50, 101, 152] + @testset for depth in [18, 34, 50, 101, 152] m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) if (SEResNet, depth) in PRETRAINED_MODELS diff --git a/test/other.jl b/test/mixers.jl similarity index 76% rename from test/other.jl rename to test/mixers.jl index df97d4f5f..5e2e8cf27 100644 --- a/test/other.jl +++ b/test/mixers.jl @@ -1,5 +1,5 @@ @testset "MLPMixer" begin - @testset for mode in [:small, :base] # :large, # :huge] + @testset for mode in [:small, :base, :large, :huge] @testset for drop_path_rate in [0.0, 0.5] m = MLPMixer(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @@ -10,7 +10,7 @@ end @testset "ResMLP" begin - @testset for mode in [:small, :base] # :large, # :huge] + @testset for mode in [:small, :base, :large, :huge] @testset for drop_path_rate in [0.0, 0.5] m = ResMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @@ -21,7 +21,7 @@ end end @testset "gMLP" begin - @testset for mode in [:small, :base] # :large, # :huge] + @testset for mode in [:small, :base, :large, :huge] @testset for drop_path_rate in [0.0, 0.5] m = gMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) diff --git a/test/runtests.jl b/test/runtests.jl index 1a8c77f25..eaef97472 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,9 +61,9 @@ end GC.safepoint() GC.gc() -# Other tests -@testset verbose = true "Other" begin - include("other.jl") +# Mixer tests +@testset verbose = true "Mixers" begin + include("mixers.jl") end GC.safepoint() @@ -71,5 +71,5 @@ GC.gc() # ViT tests @testset verbose = true "ViTs" begin - include("vit-based.jl") + include("vits.jl") end diff --git a/test/vit-based.jl b/test/vits.jl similarity index 65% rename from test/vit-based.jl rename to test/vits.jl index e889b07be..13733ddec 100644 --- a/test/vit-based.jl +++ b/test/vits.jl @@ -1,5 +1,5 @@ @testset "ViT" begin - for mode in [:small, :base, :large] # :tiny, #,:huge, :giant, :gigantic] + for mode in [:tiny, :small, :base, :large, :huge] #:giant, #:gigantic m = ViT(mode) @test size(m(x_256)) == (1000, 1) @test gradtest(m, x_256) From 8c9f73f3947f220c40838ad4ca5e71d6af635789 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 27 Jul 2022 08:09:19 +0530 Subject: [PATCH 53/64] Remove templating for now Pass through kwargs using `NamedTuples` instead of closures --- src/convnets/resnets/core.jl | 183 +++++++++++++++++------------------ src/layers/drop.jl | 4 +- 2 files changed, 91 insertions(+), 96 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 9af497844..ebfcc7496 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -15,12 +15,9 @@ Creates a basic ResNet block. - `downsample`: the downsampling function to use - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `dilation`: the dilation of the second convolution. - - `first_dilation`: the dilation of the first convolution. - - `activation`: the activation function to use. - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. + a block. See [`addact`](#) and [`actadd`](#) for an example. 
+ - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -28,19 +25,25 @@ Creates a basic ResNet block. function and passed in. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. """ -function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, +function basicblock(inplanes::Integer, planes::Integer; downsample_fns, + stride::Integer = 1, reduction_factor::Integer = 1, + connection = addact, activation = relu, + norm_layer = BatchNorm, prenorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) + expansion = expansion_factor(basicblock) first_planes = planes ÷ reduction_factor - outplanes = planes * expansion_factor(basicblock) + outplanes = planes * expansion conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm, stride, pad = 1, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm, pad = 1, bias = false) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] - return Chain(filter!(!=(identity), layers)...) + downsample = downsample_block(downsample_fns, inplanes, planes, expansion; + stride, norm_layer, prenorm) + return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...), + downsample) end expansion_factor(::typeof(basicblock)) = 1 @@ -63,11 +66,9 @@ Creates a bottleneck ResNet block. - `base_width`: the number of output feature maps for each convolutional group. - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. + a block. See [`addact`](#) and [`actadd`](#) for an example. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -75,23 +76,28 @@ Creates a bottleneck ResNet block. function and passed in. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. 
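Before the `bottleneck` definition that follows, a sketch of how these reworked blocks are consumed at the model level. The `ResNet(18)` call reflects the existing user-facing API; the lower-level `resnet` call assumes the builder signature introduced later in this patch, and since the refactor is still in flight it should be read as intended usage rather than a guaranteed-working invocation.

```julia
using Metalhead, Flux

# Depths map to block counts via `resnet_configs`; 18 and 34 use `basicblock`,
# deeper models use `bottleneck`.
model = ResNet(18)
x = rand(Float32, 224, 224, 3, 1)
@assert size(model(x)) == (1000, 1)

# Driving the lower-level builder directly with an :A-style (identity) shortcut:
layers = Metalhead.resnet(Metalhead.basicblock, [2, 2, 2, 2], :A; nclasses = 10)
```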
""" -function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, - reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, +function bottleneck(inplanes::Integer, planes::Integer; downsample_fns, stride::Integer = 1, + cardinality::Integer = 1, base_width::Integer = 64, + reduction_factor::Integer = 1, connection = addact, activation = relu, + norm_layer = BatchNorm, prenorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) + expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor - outplanes = planes * expansion_factor(bottleneck) + outplanes = planes * expansion conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm, stride, pad = 1, groups = cardinality, bias = false) conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, prenorm, bias = false) + downsample = downsample_block(downsample_fns, inplanes, planes, expansion; + stride, norm_layer, prenorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] - return Chain(filter!(!=(identity), layers)...) + return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...), + downsample) end expansion_factor(::typeof(bottleneck)) = 4 @@ -126,6 +132,12 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) end end +function downsample_block(downsample_fns, inplanes, planes, expansion; stride, kwargs...) + down_fn = (stride != 1 || inplanes != planes * expansion) ? downsample_fns[1] : + downsample_fns[2] + return down_fn(inplanes, planes * expansion; stride, kwargs...) +end + # Shortcut configurations for the ResNet models const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :B => (downsample_conv, downsample_identity), @@ -136,7 +148,7 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), # specified as a `Vector` of `Symbol`s. This is used to make the downsample # `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is # already an `NTuple{2}` of functions, it is returned unchanged. -function _make_downsample_fns(vec::Vector{<:Symbol}, layers) +function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats) downs = [] for i in vec @assert i in keys(shortcut_dict) @@ -145,19 +157,21 @@ function _make_downsample_fns(vec::Vector{<:Symbol}, layers) end return downs end -function _make_downsample_fns(sym::Symbol, layers) +function _make_downsample_fns(sym::Symbol, block_repeats) @assert sym in keys(shortcut_dict) "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - return collect(shortcut_dict[sym] for _ in 1:length(layers)) + return collect(shortcut_dict[sym] for _ in 1:length(block_repeats)) end -_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec -_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) +_make_downsample_fns(vec::Vector{<:NTuple{2}}, block_repeats) = vec +_make_downsample_fns(tup::NTuple{2}, block_repeats) = [tup for _ in 1:length(block_repeats)] # Stride for each block in the ResNet model -get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 1 : 2 +function get_stride(stage_idx::Integer, block_idx::Integer) + return (stage_idx == 1 || block_idx != 1) ? 
1 : 2 +end # returns `DropBlock`s for each stage of the ResNet as in timm. -# TODO - add experimental options for DropBlock as part of the API +# TODO - add experimental options for DropBlock as part of the API (#188) function _drop_blocks(drop_block_rate::AbstractFloat) return [ identity, identity, @@ -187,8 +201,7 @@ on how to use this function. shows peformance improvements over the `:deep` stem in some cases. - `inchannels`: The number of channels in the input. - - `replace_pool`: Whether to replace the default 3x3 max pooling layer with a - 3x3 convolution with stride 2 and a normalisation layer. + - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. """ @@ -219,104 +232,86 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - prenorm, - stride = 2, pad = 1, bias = false)...) : + prenorm, stride = 2, pad = 1, bias = false)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool), inplanes end -function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, - attn_fn = planes -> identity, kargs...) - return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor, - activation, norm_layer, prenorm, attn_fn) -end - -function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64, - reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, - attn_fn = planes -> identity, kargs...) - return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width, - reduction_factor, activation, - norm_layer, prenorm, attn_fn) -end - -function template_builder(downsample_fn::Union{typeof(downsample_conv), - typeof(downsample_pool), - typeof(downsample_identity)}; - norm_layer = BatchNorm, prenorm = false) - return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm) +function block_args(::typeof(basicblock), block_repeats; + downsample_vec, reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, + drop_path_rate = 0.0, drop_block_rate = 0.0, + attn_fn = planes -> identity) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + function get_layers(stage_idx, block_idx) + stride = get_stride(stage_idx, block_idx) + downsample_fns = downsample_vec[stage_idx] + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + return (; downsample_fns, reduction_factor, stride, activation, norm_layer, + prenorm, drop_path, drop_block, attn_fn) + end end -function configure_block(block_template, layers::Vector{Int}; expansion, - downsample_templates::Vector, inplanes::Integer = 64, - drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...) 
- pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers)) - # closure over `idxs` - function get_layers(idxs::NTuple{2, Int}) - stage_idx, block_idx = idxs - planes = 64 * 2^(stage_idx - 1) - # `get_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper. - stride = get_stride(idxs) - downsample_fns = downsample_templates[stage_idx] - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? - downsample_fns[1] : downsample_fns[2] - # DropBlock, DropPath both take in rates based on a linear scaling schedule - schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx +function block_args(::typeof(bottleneck), block_repeats; + downsample_vec, cardinality = 1, base_width = 64, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, + drop_block_rate = 0.0, drop_path_rate = 0.0, + attn_fn = planes -> identity) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + function get_layers(stage_idx, block_idx) + stride = get_stride(stage_idx, block_idx) + downsample_fns = downsample_vec[stage_idx] + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) - block = block_template(inplanes, planes; stride, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride) - # inplanes increases by expansion after each block - inplanes = (planes * expansion) - return ((block, downsample), inplanes) + return (; downsample_fns, reduction_factor, cardinality, base_width, stride, + activation, norm_layer, prenorm, drop_path, drop_block, attn_fn) end - return get_layers end # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer; - connection = addact, activation = relu, kwargs...) - outplanes = 0 +function resnet_stages(block_fn, block_repeats::Vector{<:Integer}, inplanes::Integer; + kwargs...) # Construct each stage stages = [] for (stage_idx, (num_blocks)) in enumerate(block_repeats) + planes = 64 * 2^(stage_idx - 1) + get_kwargs = block_args(block_fn, block_repeats; kwargs...) # Construct the blocks for each stage blocks = [] for block_idx in range(1, num_blocks) - layers, outplanes = get_layers((stage_idx, block_idx)) - block = Parallel(connection$activation, layers...) - push!(blocks, block) + push!(blocks, block_fn(inplanes, planes; get_kwargs(stage_idx, block_idx)...)) + inplanes = planes * expansion_factor(block_fn) end push!(stages, Chain(blocks...)) end - return Chain(stages...), outplanes + return Chain(stages...) 
end -function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; - inchannels::Integer = 3, nclasses::Integer = 1000, +function resnet(block_fn, block_repeats::Vector{<:Integer}, downsample_opt = :B; + imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64, - pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0, - kwargs...) + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = 0.0, + use_conv_classifier::Bool = false, nclasses::Integer = 1000, kwargs...) # Configure downsample templates - downsample_vec = _make_downsample_fns(downsample_opt, layers) - downsample_templates = map(x -> template_builder.(x), downsample_vec) - # Configure block templates - block_template = template_builder(block_fn; kwargs...) - get_layers = configure_block(block_template, layers; inplanes, - downsample_templates, - expansion = expansion_factor(block_fn), kwargs...) + downsample_vec = _make_downsample_fns(downsample_opt, block_repeats) # Build stages of the ResNet - stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...) + stage_blocks = resnet_stages(block_fn, block_repeats, inplanes; downsample_vec, + kwargs...) + backbone = Chain(stem, stage_blocks) # Build the classifier head - classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer, - use_conv) - return Chain(Chain(stem, stage_blocks), classifier) + outfeatures = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true) + classifier = create_classifier(outfeatures[3], nclasses; dropout_rate, pool_layer, + use_conv = use_conv_classifier) + return Chain(backbone, classifier) end # block-layer configurations for ResNet-like models diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b433fbbc9..b4a882cff 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -27,7 +27,7 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. """ # TODO add experimental `DropBlock` options from timm such as gaussian noise and -# more precise `DropBlock` to deal with edges. 
+# more precise `DropBlock` to deal with edges (#188) function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale) where {T} H, W, _, _ = size(x) @@ -63,7 +63,7 @@ function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_sc @assert 0 ≤ drop_block_prob ≤ 1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" @assert 0 ≤ gamma_scale ≤ 1 - "gamma_scale must be between 0 and 1, got $gamma_scale" + return "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) From ca53acbfdc3f130c3a2ad7cd9ca058b31055c756 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 28 Jul 2022 08:43:06 +0530 Subject: [PATCH 54/64] Fix tests, hopefully --- .github/workflows/CI.yml | 13 ++++++------- test/convnets.jl | 2 +- test/mixers.jl | 16 ++++++++-------- test/runtests.jl | 12 +++--------- 4 files changed, 18 insertions(+), 25 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 32882857b..8fa607912 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,15 +27,14 @@ jobs: - x64 suite: - '["AlexNet", "VGG"]' - - '["GoogLeNet", "SqueezeNet"]' - - '["EfficientNet", "MobileNet"]' - - '[r"/*/ResNet*", "ResNeXt"]' - - 'r"/*/Inception/Inceptionv*"' - - '["InceptionResNetv2", "Xception"]' + - '["GoogLeNet", "SqueezeNet", "MobileNet"]' + - '["EfficientNet"]' + - '[r"ResNet", r"ResNeXt"]' + - '"Inception"' - '"DenseNet"' - '["ConvNeXt", "ConvMixer"]' - - '"ViT"' - - '"Mixers"' + - 'r"ViTs"' + - 'r"Mixers"' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/test/convnets.jl b/test/convnets.jl index 860ac422b..7f80d97f2 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -107,7 +107,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4] #, :b5, :b6, :b7, :b8] + @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.efficientnet_global_configs[name][1] x = rand(Float32, r, r, 3, 1) diff --git a/test/mixers.jl b/test/mixers.jl index 5e2e8cf27..429cd357b 100644 --- a/test/mixers.jl +++ b/test/mixers.jl @@ -10,14 +10,14 @@ end @testset "ResMLP" begin - @testset for mode in [:small, :base, :large, :huge] - @testset for drop_path_rate in [0.0, 0.5] - m = ResMLP(mode; drop_path_rate) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - end - end + @testset for mode in [:small, :base, :large, :huge] + @testset for drop_path_rate in [0.0, 0.5] + m = ResMLP(mode; drop_path_rate) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end + end end @testset "gMLP" begin diff --git a/test/runtests.jl b/test/runtests.jl index eaef97472..622bfc394 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,8 +29,8 @@ function gradtest(model, input) end function normalize_imagenet(data) - cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1)) - cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1)) + cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1)) + cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1)) return (data .- cmean) ./ cstd end @@ -38,7 +38,7 @@ end const TEST_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") const TEST_IMG = 
imresize(Images.load(TEST_PATH), (224, 224)) # CHW -> WHC -const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,1)) |> normalize_imagenet +const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> normalize_imagenet # image net labels const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) @@ -58,17 +58,11 @@ x_256 = rand(Float32, 256, 256, 3, 1) include("convnets.jl") end -GC.safepoint() -GC.gc() - # Mixer tests @testset verbose = true "Mixers" begin include("mixers.jl") end -GC.safepoint() -GC.gc() - # ViT tests @testset verbose = true "ViTs" begin include("vits.jl") From 54ea5297c2255021f242510ba98b53dbebab2654 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 28 Jul 2022 19:18:20 +0530 Subject: [PATCH 55/64] Revert "Remove templating for now" This reverts commit 8c9f73f3947f220c40838ad4ca5e71d6af635789. --- src/convnets/resnets/core.jl | 183 ++++++++++++++++++----------------- src/layers/drop.jl | 4 +- 2 files changed, 96 insertions(+), 91 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index ebfcc7496..9af497844 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -15,9 +15,12 @@ Creates a basic ResNet block. - `downsample`: the downsampling function to use - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. + - `dilation`: the dilation of the second convolution. + - `first_dilation`: the dilation of the first convolution. - `activation`: the activation function to use. + - `connection`: the function applied to the output of residual and skip paths in + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -25,25 +28,19 @@ Creates a basic ResNet block. function and passed in. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. 
""" -function basicblock(inplanes::Integer, planes::Integer; downsample_fns, - stride::Integer = 1, reduction_factor::Integer = 1, - connection = addact, activation = relu, - norm_layer = BatchNorm, prenorm::Bool = false, +function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = expansion_factor(basicblock) first_planes = planes ÷ reduction_factor - outplanes = planes * expansion + outplanes = planes * expansion_factor(basicblock) conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm, stride, pad = 1, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm, pad = 1, bias = false) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] - downsample = downsample_block(downsample_fns, inplanes, planes, expansion; - stride, norm_layer, prenorm) - return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...), - downsample) + return Chain(filter!(!=(identity), layers)...) end expansion_factor(::typeof(basicblock)) = 1 @@ -66,9 +63,11 @@ Creates a bottleneck ResNet block. - `base_width`: the number of output feature maps for each convolutional group. - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. + - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. + a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses + PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -76,28 +75,23 @@ Creates a bottleneck ResNet block. function and passed in. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. 
""" -function bottleneck(inplanes::Integer, planes::Integer; downsample_fns, stride::Integer = 1, - cardinality::Integer = 1, base_width::Integer = 64, - reduction_factor::Integer = 1, connection = addact, activation = relu, - norm_layer = BatchNorm, prenorm::Bool = false, +function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor - outplanes = planes * expansion + outplanes = planes * expansion_factor(bottleneck) conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm, stride, pad = 1, groups = cardinality, bias = false) conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, prenorm, bias = false) - downsample = downsample_block(downsample_fns, inplanes, planes, expansion; - stride, norm_layer, prenorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] - return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...), - downsample) + return Chain(filter!(!=(identity), layers)...) end expansion_factor(::typeof(bottleneck)) = 4 @@ -132,12 +126,6 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) end end -function downsample_block(downsample_fns, inplanes, planes, expansion; stride, kwargs...) - down_fn = (stride != 1 || inplanes != planes * expansion) ? downsample_fns[1] : - downsample_fns[2] - return down_fn(inplanes, planes * expansion; stride, kwargs...) -end - # Shortcut configurations for the ResNet models const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :B => (downsample_conv, downsample_identity), @@ -148,7 +136,7 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), # specified as a `Vector` of `Symbol`s. This is used to make the downsample # `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is # already an `NTuple{2}` of functions, it is returned unchanged. -function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats) +function _make_downsample_fns(vec::Vector{<:Symbol}, layers) downs = [] for i in vec @assert i in keys(shortcut_dict) @@ -157,21 +145,19 @@ function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats) end return downs end -function _make_downsample_fns(sym::Symbol, block_repeats) +function _make_downsample_fns(sym::Symbol, layers) @assert sym in keys(shortcut_dict) "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - return collect(shortcut_dict[sym] for _ in 1:length(block_repeats)) + return collect(shortcut_dict[sym] for _ in 1:length(layers)) end -_make_downsample_fns(vec::Vector{<:NTuple{2}}, block_repeats) = vec -_make_downsample_fns(tup::NTuple{2}, block_repeats) = [tup for _ in 1:length(block_repeats)] +_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec +_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) # Stride for each block in the ResNet model -function get_stride(stage_idx::Integer, block_idx::Integer) - return (stage_idx == 1 || block_idx != 1) ? 1 : 2 -end +get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 
1 : 2 # returns `DropBlock`s for each stage of the ResNet as in timm. -# TODO - add experimental options for DropBlock as part of the API (#188) +# TODO - add experimental options for DropBlock as part of the API function _drop_blocks(drop_block_rate::AbstractFloat) return [ identity, identity, @@ -201,7 +187,8 @@ on how to use this function. shows peformance improvements over the `:deep` stem in some cases. - `inchannels`: The number of channels in the input. - - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. + - `replace_pool`: Whether to replace the default 3x3 max pooling layer with a + 3x3 convolution with stride 2 and a normalisation layer. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. """ @@ -232,86 +219,104 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - prenorm, stride = 2, pad = 1, bias = false)...) : + prenorm, + stride = 2, pad = 1, bias = false)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool), inplanes end -function block_args(::typeof(basicblock), block_repeats; - downsample_vec, reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, - drop_path_rate = 0.0, drop_block_rate = 0.0, - attn_fn = planes -> identity) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - function get_layers(stage_idx, block_idx) - stride = get_stride(stage_idx, block_idx) - downsample_fns = downsample_vec[stage_idx] - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) - return (; downsample_fns, reduction_factor, stride, activation, norm_layer, - prenorm, drop_path, drop_block, attn_fn) - end +function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, + attn_fn = planes -> identity, kargs...) + return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor, + activation, norm_layer, prenorm, attn_fn) end -function block_args(::typeof(bottleneck), block_repeats; - downsample_vec, cardinality = 1, base_width = 64, - reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, - drop_block_rate = 0.0, drop_path_rate = 0.0, - attn_fn = planes -> identity) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) - function get_layers(stage_idx, block_idx) - stride = get_stride(stage_idx, block_idx) - downsample_fns = downsample_vec[stage_idx] - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx +function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, + attn_fn = planes -> identity, kargs...) + return (args...; kwargs...) 
-> bottleneck(args...; kwargs..., cardinality, base_width, + reduction_factor, activation, + norm_layer, prenorm, attn_fn) +end + +function template_builder(downsample_fn::Union{typeof(downsample_conv), + typeof(downsample_pool), + typeof(downsample_identity)}; + norm_layer = BatchNorm, prenorm = false) + return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm) +end + +function configure_block(block_template, layers::Vector{Int}; expansion, + downsample_templates::Vector, inplanes::Integer = 64, + drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers)) + # closure over `idxs` + function get_layers(idxs::NTuple{2, Int}) + stage_idx, block_idx = idxs + planes = 64 * 2^(stage_idx - 1) + # `get_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper. + stride = get_stride(idxs) + downsample_fns = downsample_templates[stage_idx] + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_fns[1] : downsample_fns[2] + # DropBlock, DropPath both take in rates based on a linear scaling schedule + schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) - return (; downsample_fns, reduction_factor, cardinality, base_width, stride, - activation, norm_layer, prenorm, drop_path, drop_block, attn_fn) + block = block_template(inplanes, planes; stride, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride) + # inplanes increases by expansion after each block + inplanes = (planes * expansion) + return ((block, downsample), inplanes) end + return get_layers end # Makes the main stages of the ResNet model. This is an internal function and should not be # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function resnet_stages(block_fn, block_repeats::Vector{<:Integer}, inplanes::Integer; - kwargs...) +function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer; + connection = addact, activation = relu, kwargs...) + outplanes = 0 # Construct each stage stages = [] for (stage_idx, (num_blocks)) in enumerate(block_repeats) - planes = 64 * 2^(stage_idx - 1) - get_kwargs = block_args(block_fn, block_repeats; kwargs...) # Construct the blocks for each stage blocks = [] for block_idx in range(1, num_blocks) - push!(blocks, block_fn(inplanes, planes; get_kwargs(stage_idx, block_idx)...)) - inplanes = planes * expansion_factor(block_fn) + layers, outplanes = get_layers((stage_idx, block_idx)) + block = Parallel(connection$activation, layers...) + push!(blocks, block) end push!(stages, Chain(blocks...)) end - return Chain(stages...) 
+ return Chain(stages...), outplanes end -function resnet(block_fn, block_repeats::Vector{<:Integer}, downsample_opt = :B; - imsize::Dims{2} = (256, 256), inchannels::Integer = 3, +function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; + inchannels::Integer = 3, nclasses::Integer = 1000, stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64, - pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = 0.0, - use_conv_classifier::Bool = false, nclasses::Integer = 1000, kwargs...) + pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0, + kwargs...) # Configure downsample templates - downsample_vec = _make_downsample_fns(downsample_opt, block_repeats) + downsample_vec = _make_downsample_fns(downsample_opt, layers) + downsample_templates = map(x -> template_builder.(x), downsample_vec) + # Configure block templates + block_template = template_builder(block_fn; kwargs...) + get_layers = configure_block(block_template, layers; inplanes, + downsample_templates, + expansion = expansion_factor(block_fn), kwargs...) # Build stages of the ResNet - stage_blocks = resnet_stages(block_fn, block_repeats, inplanes; downsample_vec, - kwargs...) - backbone = Chain(stem, stage_blocks) + stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...) # Build the classifier head - outfeatures = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true) - classifier = create_classifier(outfeatures[3], nclasses; dropout_rate, pool_layer, - use_conv = use_conv_classifier) - return Chain(backbone, classifier) + classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer, + use_conv) + return Chain(Chain(stem, stage_blocks), classifier) end # block-layer configurations for ResNet-like models diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b4a882cff..b433fbbc9 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -27,7 +27,7 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. """ # TODO add experimental `DropBlock` options from timm such as gaussian noise and -# more precise `DropBlock` to deal with edges (#188) +# more precise `DropBlock` to deal with edges. 
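As the docstring context above says, end-users are expected to reach for the `DropBlock` layer rather than the internal `dropblock` kernel. A brief usage sketch (an editor's addition, assuming `DropBlock` remains reachable through the unexported `Metalhead.Layers` module and that the single-argument constructor used by the ResNet builders in this series stays available):

using Flux, Metalhead

model = Chain(Conv((3, 3), 3 => 16; pad = 1), BatchNorm(16, relu),
              Metalhead.Layers.DropBlock(0.1),   # zeroes contiguous spatial patches during training
              Conv((3, 3), 16 => 32; pad = 1), BatchNorm(32, relu))

size(model(rand(Float32, 32, 32, 3, 4)))   # (32, 32, 32, 4)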
function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale) where {T} H, W, _, _ = size(x) @@ -63,7 +63,7 @@ function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_sc @assert 0 ≤ drop_block_prob ≤ 1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" @assert 0 ≤ gamma_scale ≤ 1 - return "gamma_scale must be between 0 and 1, got $gamma_scale" + "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) From cff07cb8a6c0d1c87dd8281b7650cf4a68a67992 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 07:02:11 +0530 Subject: [PATCH 56/64] MobileNet tweaks --- src/convnets/mobilenet.jl | 327 -------------------------- src/convnets/mobilenet/mobilenetv1.jl | 11 +- 2 files changed, 3 insertions(+), 335 deletions(-) delete mode 100644 src/convnets/mobilenet.jl diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl deleted file mode 100644 index 8de98e886..000000000 --- a/src/convnets/mobilenet.jl +++ /dev/null @@ -1,327 +0,0 @@ -# MobileNetv1 - -""" - mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - nclasses = 1000) - -Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `dw`: Set true to use a depthwise separable convolution or false for regular convolution - + `o`: The number of output feature maps - + `s`: The stride of the convolutional kernel - + `r`: The number of time this configuration block is repeated - - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. The default value is 3. - - `nclasses`: The number of output classes -""" -function mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - nclasses = 1000) - layers = [] - for (dw, outch, stride, nrepeats) in config - outch = Int(outch * width_mult) - for _ in 1:nrepeats - layer = dw ? - depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1, - bias = false) - append!(layers, layer) - inchannels = outch - end - end - - return Chain(Chain(layers), - Chain(GlobalMeanPool(), - MLUtils.flatten, - Dense(inchannels, nclasses))) -end - -const mobilenetv1_configs = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), -] - -""" - MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv1 model with the baseline configuration -([reference](https://arxiv.org/abs/1704.04861v1)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. 
- - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv1`](#). -""" -struct MobileNetv1 - layers::Any -end - -function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv1")) - return MobileNetv1(layers) -end - -@functor MobileNetv1 - -(m::MobileNetv1)(x) = m.layers(x) - -backbone(m::MobileNetv1) = m.layers[1] -classifier(m::MobileNetv1) = m.layers[2] - -# MobileNetv2 - -""" - mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - -Create a MobileNetv2 model. -([reference](https://arxiv.org/abs/1801.04381)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer - + `c`: The number of output feature maps - + `n`: The number of times a block is repeated - + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. The default value is 3. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: The number of output classes -""" -function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - # building first layer - inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) - layers = [] - append!(layers, conv_bn((3, 3), inchannels, inplanes; pad = 1, stride = 2)) - # building inverted residual blocks - for (t, c, n, s, a) in configs - outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) - for i in 1:n - push!(layers, - invertedresidual(3, inplanes, inplanes * t, outplanes, a; - stride = i == 1 ? s : 1)) - inplanes = outplanes - end - end - # building last several layers - outplanes = (width_mult > 1) ? - _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : - max_width - return Chain(Chain(Chain(layers), - conv_bn((1, 1), inplanes, outplanes, relu6; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(outplanes, nclasses))) -end - -# Layer configurations for MobileNetv2 -const mobilenetv2_configs = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), -] - -# Model definition for MobileNetv2 -struct MobileNetv2 - layers::Any -end - -""" - MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv2 model with the specified configuration. -([reference](https://arxiv.org/abs/1801.04381)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv2`](#). 
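For readers tracking the deletion above, the `(t, c, n, s, a)` tuples are easiest to read with one entry spelled out. An editor's sketch, with the default `width_mult = 1` and `inplanes = 16` assumed as the width left behind by the first config entry:

using Flux   # for relu6

t, c, n, s, a = (6, 24, 2, 2, relu6)   # one row of mobilenetv2_configs
inplanes = 16                          # channels coming out of the previous stage
hidden   = inplanes * t                # 96 channels inside the inverted residual
# n = 2 blocks are emitted for this row; only the first uses stride s = 2,
# matching `stride = i == 1 ? s : 1` in the deleted `mobilenetv2` above.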
-""" -function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv2")) - return MobileNetv2(layers) -end - -@functor MobileNetv2 - -(m::MobileNetv2)(x) = m.layers(x) - -backbone(m::MobileNetv2) = m.layers[1] -classifier(m::MobileNetv2) = m.layers[2] - -# MobileNetv3 - -""" - mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - -Create a MobileNetv3 model. -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - - `configs`: a "list of tuples" configuration for each layer that details: - - + `k::Integer` - The size of the convolutional kernel - + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - + `t::Integer` - The number of output feature maps for a given block - + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - + `s::Integer` - The stride of the convolutional kernel - + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. The default value is 3. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: the number of output classes -""" -function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - # building first layer - inplanes = _round_channels(16 * width_mult, 8) - layers = [] - append!(layers, - conv_bn((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, - bias = false)) - explanes = 0 - # building inverted residual blocks - for (k, t, c, r, a, s) in configs - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - invertedresidual(k, inplanes, explanes, outplanes, a; - stride = s, reduction = r)) - inplanes = outplanes - end - # building last several layers - output_channel = max_width - output_channel = width_mult > 1.0 ? 
_round_channels(output_channel * width_mult, 8) : - output_channel - classifier = Chain(Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses)) - return Chain(Chain(Chain(layers), - conv_bn((1, 1), inplanes, explanes, hardswish; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) -end - -# Configurations for small and large mode for MobileNetv3 -mobilenetv3_configs = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) - -# Model definition for MobileNetv3 -struct MobileNetv3 - layers::Any -end - -""" - MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv3 model with the specified configuration. -([reference](https://arxiv.org/abs/1905.02244)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -# Arguments - - - `mode`: :small or :large for the size of the model (see paper). - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of channels in the input. The default value is 3. - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -See also [`Metalhead.mobilenetv3`](#). -""" -function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, - pretrain = false, nclasses = 1000) - @assert mode in [:large, :small] "`mode` has to be either :large or :small" - max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, - nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv3", mode)) - return MobileNetv3(layers) -end - -@functor MobileNetv3 - -(m::MobileNetv3)(x) = m.layers(x) - -backbone(m::MobileNetv3) = m.layers[1] -classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index 4add739b4..54237446e 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -2,7 +2,6 @@ mobilenetv1(width_mult, config; activation = relu, inchannels = 3, - fcsize = 1024, nclasses = 1000) Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). @@ -19,14 +18,12 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + `s`: The stride of the convolutional kernel + `r`: The number of time this configuration block is repeated - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. 
- - `fcsize`: The intermediate fully-connected size between the convolution and final layers + - `inchannels`: The number of input channels. The default value is 3. - `nclasses`: The number of output classes """ function mobilenetv1(width_mult, config; activation = relu, inchannels = 3, - fcsize = 1024, nclasses = 1000) layers = [] for (dw, outch, stride, nrepeats) in config @@ -35,8 +32,7 @@ function mobilenetv1(width_mult, config; layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1, bias = false) : - conv_norm((3, 3), inchannels, outch, activation; stride = stride, - pad = 1, + conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, bias = false) append!(layers, layer) inchannels = outch @@ -46,8 +42,7 @@ function mobilenetv1(width_mult, config; return Chain(Chain(layers), Chain(GlobalMeanPool(), MLUtils.flatten, - Dense(inchannels, fcsize, activation), - Dense(fcsize, nclasses))) + Dense(inchannels, nclasses))) end const mobilenetv1_configs = [ From 674b27e79a55d46b6dbe2ef184012c53239e4961 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 08:51:48 +0530 Subject: [PATCH 57/64] Make templating work again And expand the lowest level of the ResNet API --- src/convnets/resnets/core.jl | 128 +++++++++++++++++------------------ src/layers/drop.jl | 4 +- 2 files changed, 64 insertions(+), 68 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 9af497844..533decf21 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -132,32 +132,13 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :C => (downsample_conv, downsample_conv), :D => (downsample_pool, downsample_identity)) -# Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is -# specified as a `Vector` of `Symbol`s. This is used to make the downsample -# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is -# already an `NTuple{2}` of functions, it is returned unchanged. -function _make_downsample_fns(vec::Vector{<:Symbol}, layers) - downs = [] - for i in vec - @assert i in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - push!(downs, shortcut_dict[i]) - end - return downs -end -function _make_downsample_fns(sym::Symbol, layers) - @assert sym in keys(shortcut_dict) - "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))" - return collect(shortcut_dict[sym] for _ in 1:length(layers)) -end -_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec -_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers)) - # Stride for each block in the ResNet model -get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 1 : 2 +function get_stride(block_idx::Integer, stage_idx::Integer) + return (stage_idx == 1 || block_idx != 1) ? 1 : 2 +end # returns `DropBlock`s for each stage of the ResNet as in timm. 
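To spell out the stride rule above (an editor's illustration; the `[3, 4, 6, 3]` layout is the usual ResNet-50-style configuration, and only the first block of each stage after the first downsamples):

strides = [[(stage == 1 || block != 1) ? 1 : 2 for block in 1:nblocks]
           for (stage, nblocks) in enumerate([3, 4, 6, 3])]
# [[1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1, 1, 1], [2, 1, 1]]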
-# TODO - add experimental options for DropBlock as part of the API +# TODO - add experimental options for DropBlock as part of the API (#188) function _drop_blocks(drop_block_rate::AbstractFloat) return [ identity, identity, @@ -225,16 +206,24 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, return Chain(conv1, bn1, stempool), inplanes end -function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, +# Templating builders for the blocks and the downsampling layers +function template_builder(block_fn; kwargs...) + function (inplanes, planes; _kwargs...) + return block_fn(inplanes, planes; kwargs..., _kwargs...) + end +end + +function template_builder(::typeof(basicblock); reduction_factor::Integer = 1, + activation = relu, norm_layer = BatchNorm, prenorm::Bool = false, attn_fn = planes -> identity, kargs...) return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor, activation, norm_layer, prenorm, attn_fn) end -function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64, - reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, +function template_builder(::typeof(bottleneck); cardinality::Integer = 1, + base_width::Integer = 64, + reduction_factor::Integer = 1, activation = relu, + norm_layer = BatchNorm, prenorm::Bool = false, attn_fn = planes -> identity, kargs...) return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width, reduction_factor, activation, @@ -248,30 +237,32 @@ function template_builder(downsample_fn::Union{typeof(downsample_conv), return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm) end -function configure_block(block_template, layers::Vector{Int}; expansion, - downsample_templates::Vector, inplanes::Integer = 64, - drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers)) +resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1) + +function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer}; + stride_fn = get_stride, plane_fn = resnet_planes, + downsample_templates::NTuple{2, Any}, + inplanes::Integer = 64, + drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) # closure over `idxs` - function get_layers(idxs::NTuple{2, Int}) - stage_idx, block_idx = idxs - planes = 64 * 2^(stage_idx - 1) + function get_layers(stage_idx::Integer, block_idx::Integer) + planes = plane_fn(stage_idx) # `get_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper. - stride = get_stride(idxs) - downsample_fns = downsample_templates[stage_idx] - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? - downsample_fns[1] : downsample_fns[2] + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_template = (stride != 1 || inplanes != planes * expansion) ? 
+ downsample_templates[1] : downsample_templates[2] # DropBlock, DropPath both take in rates based on a linear scaling schedule - schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) block = block_template(inplanes, planes; stride, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride) + downsample = downsample_template(inplanes, planes * expansion; stride) # inplanes increases by expansion after each block inplanes = (planes * expansion) - return ((block, downsample), inplanes) + return block, downsample end return get_layers end @@ -280,43 +271,48 @@ end # used by end-users. `block_fn` is a function that returns a single block of the ResNet. # See `basicblock` and `bottleneck` for examples. A block must define a function # `expansion(::typeof(block))` that returns the expansion factor of the block. -function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer; - connection = addact, activation = relu, kwargs...) - outplanes = 0 +function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) # Construct each stage stages = [] for (stage_idx, (num_blocks)) in enumerate(block_repeats) # Construct the blocks for each stage - blocks = [] - for block_idx in range(1, num_blocks) - layers, outplanes = get_layers((stage_idx, block_idx)) - block = Parallel(connection$activation, layers...) - push!(blocks, block) - end + blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...) + for block_idx in range(1, num_blocks)] push!(stages, Chain(blocks...)) end - return Chain(stages...), outplanes + return Chain(stages...) end -function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B; - inchannels::Integer = 3, nclasses::Integer = 1000, +function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier) + stage_blocks = resnet_stages(get_layers, block_repeats, connection) + return Chain(Chain(stem, stage_blocks), classifier) +end + +function resnet(block_fn, block_repeats::Vector{<:Integer}, + downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); + imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64, - pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0, - kwargs...) + connection = addact, activation = relu, + pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false, + dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...) # Configure downsample templates - downsample_vec = _make_downsample_fns(downsample_opt, layers) - downsample_templates = map(x -> template_builder.(x), downsample_vec) + downsample_templates = map(template_builder, downsample_opt) # Configure block templates block_template = template_builder(block_fn; kwargs...) - get_layers = configure_block(block_template, layers; inplanes, - downsample_templates, - expansion = expansion_factor(block_fn), kwargs...) + get_layers = configure_resnet_block(block_template, expansion_factor(block_fn), + block_repeats; inplanes, downsample_templates, + kwargs...) # Build stages of the ResNet - stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...) 
+ stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation) + backbone = Chain(stem, stage_blocks) # Build the classifier head - classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer, + nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3] + classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer, use_conv) - return Chain(Chain(stem, stage_blocks), classifier) + return Chain(backbone, classifier) +end +function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) + return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt], kwargs...) end # block-layer configurations for ResNet-like models diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b433fbbc9..b4a882cff 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -27,7 +27,7 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. """ # TODO add experimental `DropBlock` options from timm such as gaussian noise and -# more precise `DropBlock` to deal with edges. +# more precise `DropBlock` to deal with edges (#188) function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale) where {T} H, W, _, _ = size(x) @@ -63,7 +63,7 @@ function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_sc @assert 0 ≤ drop_block_prob ≤ 1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" @assert 0 ≤ gamma_scale ≤ 1 - "gamma_scale must be between 0 and 1, got $gamma_scale" + return "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) From aa2a9ef8a62b2d1eb43f5c34e0f435255a054a14 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 09:21:44 +0530 Subject: [PATCH 58/64] Tests just don't fix themselves --- test/convnets.jl | 2 +- test/mixers.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/convnets.jl b/test/convnets.jl index 7f80d97f2..8e4533e9a 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -107,7 +107,7 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8] + @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] # preferred image resolution scaling r = Metalhead.efficientnet_global_configs[name][1] x = rand(Float32, r, r, 3, 1) diff --git a/test/mixers.jl b/test/mixers.jl index 429cd357b..885ff5838 100644 --- a/test/mixers.jl +++ b/test/mixers.jl @@ -1,5 +1,5 @@ @testset "MLPMixer" begin - @testset for mode in [:small, :base, :large, :huge] + @testset for mode in [:small, :base, :large] #:huge] @testset for drop_path_rate in [0.0, 0.5] m = MLPMixer(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @@ -10,7 +10,7 @@ end @testset "ResMLP" begin - @testset for mode in [:small, :base, :large, :huge] + @testset for mode in [:small, :base, :large] #:huge] @testset for drop_path_rate in [0.0, 0.5] m = ResMLP(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @@ -21,7 +21,7 @@ end end @testset "gMLP" begin - @testset for mode in [:small, :base, :large, :huge] + @testset for mode in [:small, :base, :large] #:huge] @testset for drop_path_rate in [0.0, 0.5] m = gMLP(mode; 
drop_path_rate) @test size(m(x_224)) == (1000, 1) From b143b950e6634e28b62a5638b1ff671acd8f8fa5 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 20:14:51 +0530 Subject: [PATCH 59/64] Fifth refactor is a charm Also, we aren't using the skips anymore --- src/convnets/resnets/core.jl | 212 +++++++++++++++++++---------------- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 45 -------- test/convnets.jl | 2 +- 4 files changed, 115 insertions(+), 146 deletions(-) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 533decf21..ad597efbf 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,9 +1,8 @@ """ - basicblock(inplanes, planes; stride = 1, downsample = identity, - reduction_factor = 1, dilation = 1, first_dilation = dilation, - activation = relu, connection = addact\$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity) + basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) Creates a basic ResNet block. @@ -12,15 +11,9 @@ Creates a basic ResNet block. - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block - - `downsample`: the downsampling function to use - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `dilation`: the dilation of the second convolution. - - `first_dilation`: the dilation of the first convolution. - `activation`: the activation function to use. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -28,8 +21,9 @@ Creates a basic ResNet block. function and passed in. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. """ -function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, +function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, + reduction_factor::Integer = 1, activation = relu, + norm_layer = BatchNorm, prenorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor @@ -45,11 +39,11 @@ end expansion_factor(::typeof(basicblock)) = 1 """ - bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, reduction_factor = 1, first_dilation = 1, - activation = relu, connection = addact\$activation, - norm_layer = BatchNorm, drop_block = identity, drop_path = identity, - attn_fn = planes -> identity) + bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, prenorm = false, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) Creates a bottleneck ResNet block. @@ -58,16 +52,11 @@ Creates a bottleneck ResNet block. 
- `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block - - `downsample`: the downsampling function to use - `cardinality`: the number of groups in the convolution. - `base_width`: the number of output feature maps for each convolutional group. - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first convolution. - - `first_dilation`: the dilation of the 3x3 convolution. - `activation`: the activation function to use. - - `connection`: the function applied to the output of residual and skip paths in - a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses - PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`. - `norm_layer`: the normalization layer to use. - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` function and passed in. @@ -75,9 +64,10 @@ Creates a bottleneck ResNet block. function and passed in. - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. """ -function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, - reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, +function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, + cardinality::Integer = 1, base_width::Integer = 64, + reduction_factor::Integer = 1, activation = relu, + norm_layer = BatchNorm, prenorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) width = floor(Int, planes * (base_width / 64)) * cardinality @@ -113,6 +103,7 @@ end # Downsample layer which is an identity projection. Uses max pooling # when the output size is more than the input size. +# TODO - figure out how to make this work when outplanes < inplanes function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) if outplanes > inplanes return Chain(MaxPool((1, 1); stride = 2), @@ -174,8 +165,8 @@ on how to use this function. - `activation`: The activation function used in the stem. """ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, - replace_pool::Bool = false, norm_layer = BatchNorm, prenorm = false, - activation = relu) + replace_pool::Bool = false, activation = relu, + norm_layer = BatchNorm, prenorm::Bool = false) @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem @@ -203,65 +194,70 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, prenorm, stride = 2, pad = 1, bias = false)...) : MaxPool((3, 3); stride = 2, pad = 1) - return Chain(conv1, bn1, stempool), inplanes -end - -# Templating builders for the blocks and the downsampling layers -function template_builder(block_fn; kwargs...) - function (inplanes, planes; _kwargs...) - return block_fn(inplanes, planes; kwargs..., _kwargs...) - end -end - -function template_builder(::typeof(basicblock); reduction_factor::Integer = 1, - activation = relu, norm_layer = BatchNorm, prenorm::Bool = false, - attn_fn = planes -> identity, kargs...) - return (args...; kwargs...) 
-> basicblock(args...; kwargs..., reduction_factor, - activation, norm_layer, prenorm, attn_fn) + return Chain(conv1, bn1, stempool) end -function template_builder(::typeof(bottleneck); cardinality::Integer = 1, - base_width::Integer = 64, - reduction_factor::Integer = 1, activation = relu, - norm_layer = BatchNorm, prenorm::Bool = false, - attn_fn = planes -> identity, kargs...) - return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width, - reduction_factor, activation, - norm_layer, prenorm, attn_fn) -end +resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1) -function template_builder(downsample_fn::Union{typeof(downsample_conv), - typeof(downsample_pool), - typeof(downsample_identity)}; - norm_layer = BatchNorm, prenorm = false) - return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm) +function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, + reduction_factor::Integer = 1, expansion::Integer = 1, + norm_layer = BatchNorm, prenorm::Bool = false, + activation = relu, attn_fn = planes -> identity, + drop_block_rate = 0.0, drop_path_rate = 0.0, + stride_fn = get_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + planes = planes_fn(stage_idx) + # `get_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_tuple[1] : downsample_tuple[2] + # DropBlock, DropPath both take in rates based on a linear scaling schedule + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = basicblock(inplanes, planes; stride, reduction_factor, activation, + norm_layer, prenorm, attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride) + # inplanes increases by expansion after each block + inplanes = planes * expansion + return block, downsample + end + return get_layers end -resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1) - -function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer}; - stride_fn = get_stride, plane_fn = resnet_planes, - downsample_templates::NTuple{2, Any}, - inplanes::Integer = 64, - drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...) 
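Two conventions in `basicblock_builder` above are worth spelling out: every block maps to one flat index into the linear DropPath/DropBlock schedules, and each call to the returned closure yields a `(block, downsample)` pair that `resnet_stages` later wraps in a `Parallel`. A toy sketch of both (an editor's addition; `toy_get_layers` and the `Dense` stand-ins are not part of the patch):

using Flux

block_repeats = [3, 4, 6, 3]                        # ResNet-50-style stage layout
# flat index into the drop-rate schedules, as computed inside the builder
schedule_idx(stage, block) = sum(block_repeats[1:(stage - 1)]) + block
schedule_idx(3, 2)                                  # == 3 + 4 + 2 == 9

# the builder contract: a (branch, shortcut) pair per (stage_idx, block_idx)
toy_get_layers(stage_idx, block_idx) = (Dense(8 => 8, relu), identity)
stage1 = Chain([Parallel(+, toy_get_layers(1, b)...) for b in 1:block_repeats[1]]...)
size(stage1(rand(Float32, 8, 5)))                   # (8, 5)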
+function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, + cardinality::Integer = 1, base_width::Integer = 64, + reduction_factor::Integer = 1, expansion::Integer = 4, + norm_layer = BatchNorm, prenorm::Bool = false, + activation = relu, attn_fn = planes -> identity, + drop_block_rate = 0.0, drop_path_rate = 0.0, + stride_fn = get_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) - planes = plane_fn(stage_idx) + planes = planes_fn(stage_idx) # `get_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) - downsample_template = (stride != 1 || inplanes != planes * expansion) ? - downsample_templates[1] : downsample_templates[2] + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_tuple[1] : downsample_tuple[2] # DropBlock, DropPath both take in rates based on a linear scaling schedule schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) - block = block_template(inplanes, planes; stride, drop_path, drop_block) - downsample = downsample_template(inplanes, planes * expansion; stride) + block = bottleneck(inplanes, planes; stride, cardinality, base_width, + reduction_factor, activation, norm_layer, prenorm, + attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride) # inplanes increases by expansion after each block - inplanes = (planes * expansion) + inplanes = planes * expansion return block, downsample end return get_layers @@ -283,41 +279,59 @@ function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) return Chain(stages...) 
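For illustration only, a hypothetical sketch of how a builder closure like the one above is consumed; `bottleneck_builder` and `resnet_stages` are internal to Metalhead, so nothing here is taken from the patch itself:

    using Flux, Metalhead

    # builder for a ResNet-50-style layout; DropPath/DropBlock rates come from linear
    # schedules over sum(block_repeats) = 16 blocks, e.g. stage 3 / block 2 maps to
    # schedule index sum(block_repeats[1:(3 - 1)]) + 2 = 3 + 4 + 2 = 9
    get_layers = Metalhead.bottleneck_builder([3, 4, 6, 3]; inplanes = 64)
    block, shortcut = get_layers(1, 1)   # first block of the first stage
    # `resnet_stages` then wraps each pair as Parallel(connection, block, shortcut)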
end -function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier) - stage_blocks = resnet_stages(get_layers, block_repeats, connection) - return Chain(Chain(stem, stage_blocks), classifier) +function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; + downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity), + cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, + reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), + inchannels::Integer = 3, stem_fn = resnet_stem, + connection = addact, activation = relu, norm_layer = BatchNorm, + prenorm::Bool = false, attn_fn = planes -> identity, + pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false, + drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0, + nclasses::Integer = 1000) + # Build stem + stem = stem_fn(; inchannels) + # Block builder + if block_type == :basicblock + @assert cardinality==1 "Cardinality must be 1 for `basicblock`" + @assert base_width==64 "Base width must be 64 for `basicblock`" + get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, prenorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = get_stride, planes_fn = resnet_planes, + downsample_tuple = downsample_opt) + elseif block_type == :bottleneck + get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, + reduction_factor, activation, norm_layer, + prenorm, attn_fn, drop_block_rate, drop_path_rate, + stride_fn = get_stride, planes_fn = resnet_planes, + downsample_tuple = downsample_opt) + else + throw(ArgumentError("Unknown block type $block_type")) + end + classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, + pool_layer, use_conv) + return resnet((imsize..., inchannels), stem, connection$activation, get_layers, + block_repeats, classifier_fn) +end +function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) + return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...) end -function resnet(block_fn, block_repeats::Vector{<:Integer}, - downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); - imsize::Dims{2} = (256, 256), inchannels::Integer = 3, - stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64, - connection = addact, activation = relu, - pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false, - dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...) - # Configure downsample templates - downsample_templates = map(template_builder, downsample_opt) - # Configure block templates - block_template = template_builder(block_fn; kwargs...) - get_layers = configure_resnet_block(block_template, expansion_factor(block_fn), - block_repeats; inplanes, downsample_templates, - kwargs...) 
+function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer}, + classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation) + stage_blocks = resnet_stages(get_layers, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Build the classifier head - nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3] - classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer, - use_conv) + nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] + classifier = classifier_fn(nfeaturemaps) return Chain(backbone, classifier) end -function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) - return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt], kwargs...) -end # block-layer configurations for ResNet-like models -const resnet_configs = Dict(18 => (basicblock, [2, 2, 2, 2]), - 34 => (basicblock, [3, 4, 6, 3]), - 50 => (bottleneck, [3, 4, 6, 3]), - 101 => (bottleneck, [3, 4, 23, 3]), - 152 => (bottleneck, [3, 8, 36, 3])) +const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]), + 34 => (:basicblock, [3, 4, 6, 3]), + 50 => (:bottleneck, [3, 4, 6, 3]), + 101 => (:bottleneck, [3, 4, 23, 3]), + 152 => (:bottleneck, [3, 8, 36, 3])) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 9b5aa588b..2bc996278 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -26,7 +26,7 @@ include("normalise.jl") export prenorm, ChannelLayerNorm include("conv.jl") -export conv_norm, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection +export conv_norm, depthwise_sep_conv_bn, invertedresidual include("drop.jl") export DropBlock, DropPath, droppath_rates diff --git a/src/layers/conv.jl b/src/layers/conv.jl index b28507d6c..6f8a31d47 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -96,51 +96,6 @@ function depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = re use_bn = use_bn[2])) end -""" - skip_projection(inplanes, outplanes, downsample = false) - -Create a skip projection -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `inplanes`: number of input feature maps - - `outplanes`: number of output feature maps - - `downsample`: set to `true` to downsample the input -""" -function skip_projection(inplanes, outplanes, downsample = false) - return downsample ? - Chain(conv_norm((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)) : - Chain(conv_norm((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)) -end - -# array -> PaddedView(0, array, outplanes) for zero padding arrays -""" - skip_identity(inplanes, outplanes[, downsample]) - -Create a identity projection -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps - - `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#). 
-""" -function skip_identity(inplanes, outplanes) - if outplanes > inplanes - return Chain(MaxPool((1, 1); stride = 2), - y -> cat_channels(y, - zeros(eltype(y), - size(y, 1), - size(y, 2), - outplanes - inplanes, size(y, 4)))) - else - return identity - end -end -skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes) - """ invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; stride, reduction = nothing) diff --git a/test/convnets.jl b/test/convnets.jl index 8e4533e9a..1b44676bc 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -32,7 +32,7 @@ end # end @testset "resnet" begin - @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] + @testset for block_fn in [:basicblock, :bottleneck] layer_list = [ [2, 2, 2, 2], [3, 4, 6, 3], From fc74aa1169fbfe3581d87c5430ffbd8928a41c70 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 22:14:30 +0530 Subject: [PATCH 60/64] Cleanup - docs and code Co-Authored-By: Kyle Daruwalla --- src/convnets/densenet.jl | 6 +- src/convnets/inception/xception.jl | 6 +- src/convnets/mobilenet/mobilenetv1.jl | 2 +- src/convnets/resnets/core.jl | 124 ++++++++++++-------------- src/convnets/resnets/seresnet.jl | 4 +- src/layers/Layers.jl | 2 +- src/layers/conv.jl | 36 ++++---- src/layers/pool.jl | 3 +- test/convnets.jl | 6 +- 9 files changed, 93 insertions(+), 96 deletions(-) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 0c5bd6ad6..c41b4028b 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -13,9 +13,9 @@ Create a Densenet bottleneck layer function dense_bottleneck(inplanes, outplanes) inner_channels = 4 * outplanes return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, - prenorm = true)..., + revnorm = true)..., conv_norm((3, 3), inner_channels, outplanes; pad = 1, - bias = false, prenorm = true)...), + bias = false, revnorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes, outplanes) - return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, prenorm = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)..., MeanPool((2, 2))) end diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index 6f4928385..a585aadd4 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -34,7 +34,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, end push!(layers, relu) append!(layers, - depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, + depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, use_bn = (false, false))) push!(layers, BatchNorm(outc)) end @@ -63,8 +63,8 @@ function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) xception_block(256, 728, 2; stride = 2), [xception_block(728, 728, 3) for _ in 1:8]..., xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) + depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)..., + depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...) 
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), Dense(2048, nclasses)) return Chain(body, head) diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index 54237446e..22beaf86f 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -30,7 +30,7 @@ function mobilenetv1(width_mult, config; outch = Int(outch * width_mult) for _ in 1:nrepeats layer = dw ? - depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; + depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; stride = stride, pad = 1, bias = false) : conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, bias = false) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index ad597efbf..e42b81c70 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,51 +1,48 @@ """ basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, + norm_layer = BatchNorm, revnorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) -Creates a basic ResNet block. +Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)). # Arguments - `inplanes`: number of input feature maps - `planes`: number of feature maps for the block - `stride`: the stride of the block - - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first - convolution. + - `reduction_factor`: the factor by which the input feature maps + are reduced before the first convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. - - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` - function and passed in. + - `drop_block`: the drop block layer + - `drop_path`: the drop path layer - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. """ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, reduction_factor::Integer = 1, activation = relu, - norm_layer = BatchNorm, prenorm::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor outplanes = planes * expansion_factor(basicblock) - conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm, + conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, stride, pad = 1, bias = false) - conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm, + conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, pad = 1, bias = false) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) end -expansion_factor(::typeof(basicblock)) = 1 """ bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, reduction_factor = 1, activation = relu, - norm_layer = BatchNorm, prenorm = false, + norm_layer = BatchNorm, revnorm = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) -Creates a bottleneck ResNet block. +Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)). 
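As a purely illustrative sketch (not part of the patch), one revised basic block can be built directly; `basicblock` is an internal helper, hence the module qualification:

    using Flux, Metalhead

    blk = Metalhead.basicblock(64, 128; stride = 2)
    # roughly: 3x3 Conv 64=>128 (stride 2), BatchNorm, relu, 3x3 Conv 128=>128, BatchNorm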
# Arguments @@ -58,46 +55,43 @@ Creates a bottleneck ResNet block. convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. - - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks` - function and passed in. - - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks` - function and passed in. + - `drop_block`: the drop block layer + - `drop_path`: the drop path layer - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. """ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, cardinality::Integer = 1, base_width::Integer = 64, reduction_factor::Integer = 1, activation = relu, - norm_layer = BatchNorm, prenorm::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * expansion_factor(bottleneck) - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm, + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, bias = false) - conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm, + conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, stride, pad = 1, groups = cardinality, bias = false) - conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, prenorm, + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm, bias = false) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) end -expansion_factor(::typeof(bottleneck)) = 4 # Downsample layer using convolutions. function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, prenorm = false) - return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, prenorm, + norm_layer = BatchNorm, revnorm = false) + return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, pad = SamePad(), stride, bias = false)...) end # Downsample layer using max pooling function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, - norm_layer = BatchNorm, prenorm = false) + norm_layer = BatchNorm, revnorm = false) pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, prenorm, + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, bias = false)...) end @@ -124,18 +118,18 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), :D => (downsample_pool, downsample_identity)) # Stride for each block in the ResNet model -function get_stride(block_idx::Integer, stage_idx::Integer) +function resnet_stride(stage_idx::Integer, block_idx::Integer) return (stage_idx == 1 || block_idx != 1) ? 1 : 2 end # returns `DropBlock`s for each stage of the ResNet as in timm. 
# TODO - add experimental options for DropBlock as part of the API (#188) -function _drop_blocks(drop_block_rate::AbstractFloat) - return [ - identity, identity, - DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00), - ] -end +# function _drop_blocks(drop_block_rate::AbstractFloat) +# return [ +# identity, identity, +# DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00), +# ] +# end """ resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, @@ -159,14 +153,13 @@ on how to use this function. shows peformance improvements over the `:deep` stem in some cases. - `inchannels`: The number of channels in the input. - - `replace_pool`: Whether to replace the default 3x3 max pooling layer with a - 3x3 convolution with stride 2 and a normalisation layer. + - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. - `norm_layer`: The normalisation layer used in the stem. - `activation`: The activation function used in the stem. """ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, replace_pool::Bool = false, activation = relu, - norm_layer = BatchNorm, prenorm::Bool = false) + norm_layer = BatchNorm, revnorm::Bool = false) @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Main stem @@ -180,7 +173,7 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, stem_channels = (3 * (stem_width ÷ 4), stem_width) end conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; - norm_layer, prenorm, stride = 2, pad = 1, bias = false)..., + norm_layer, revnorm, stride = 2, pad = 1, bias = false)..., conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; norm_layer, pad = 1, bias = false)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) @@ -191,7 +184,7 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, # Stem pooling stempool = replace_pool ? Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, - prenorm, + revnorm, stride = 2, pad = 1, bias = false)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool) @@ -201,17 +194,17 @@ resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1) function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, reduction_factor::Integer = 1, expansion::Integer = 1, - norm_layer = BatchNorm, prenorm::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, drop_block_rate = 0.0, drop_path_rate = 0.0, - stride_fn = get_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) planes = planes_fn(stage_idx) - # `get_stride` is a callback that the user can tweak to change the stride of the + # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) downsample_fn = (stride != 1 || inplanes != planes * expansion) ? 
@@ -221,7 +214,7 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) block = basicblock(inplanes, planes; stride, reduction_factor, activation, - norm_layer, prenorm, attn_fn, drop_path, drop_block) + norm_layer, revnorm, attn_fn, drop_path, drop_block) downsample = downsample_fn(inplanes, planes * expansion; stride) # inplanes increases by expansion after each block inplanes = planes * expansion @@ -233,17 +226,17 @@ end function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, cardinality::Integer = 1, base_width::Integer = 64, reduction_factor::Integer = 1, expansion::Integer = 4, - norm_layer = BatchNorm, prenorm::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, drop_block_rate = 0.0, drop_path_rate = 0.0, - stride_fn = get_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) planes = planes_fn(stage_idx) - # `get_stride` is a callback that the user can tweak to change the stride of the + # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) downsample_fn = (stride != 1 || inplanes != planes * expansion) ? @@ -253,7 +246,7 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) block = bottleneck(inplanes, planes; stride, cardinality, base_width, - reduction_factor, activation, norm_layer, prenorm, + reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) downsample = downsample_fn(inplanes, planes * expansion; stride) # inplanes increases by expansion after each block @@ -270,22 +263,33 @@ end function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) # Construct each stage stages = [] - for (stage_idx, (num_blocks)) in enumerate(block_repeats) + for (stage_idx, num_blocks) in enumerate(block_repeats) # Construct the blocks for each stage blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...) - for block_idx in range(1, num_blocks)] + for block_idx in 1:num_blocks] push!(stages, Chain(blocks...)) end return Chain(stages...) 
end +function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection, + classifier_fn) + # Build stages of the ResNet + stage_blocks = resnet_stages(get_layers, block_repeats, connection) + backbone = Chain(stem, stage_blocks) + # Build the classifier head + nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] + classifier = classifier_fn(nfeaturemaps) + return Chain(backbone, classifier) +end + function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity), cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, - prenorm::Bool = false, attn_fn = planes -> identity, + revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0, nclasses::Integer = 1000) @@ -296,17 +300,18 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, - activation, norm_layer, prenorm, attn_fn, + activation, norm_layer, revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = get_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = downsample_opt) elseif block_type == :bottleneck get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, reduction_factor, activation, norm_layer, - prenorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = get_stride, planes_fn = resnet_planes, + revnorm, attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = downsample_opt) else + # TODO: write better message when we have link to dev docs for resnet throw(ArgumentError("Unknown block type $block_type")) end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, @@ -318,17 +323,6 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...) 
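For illustration, a hypothetical sketch of exercising the reworked entry points; the exported `ResNet` wrapper mirrors the updated tests later in this series, while the symbol-based `resnet` builder is internal to Metalhead:

    using Flux, Metalhead

    m = ResNet(50)
    size(m(rand(Float32, 224, 224, 3, 1)))   # expected (1000, 1)

    # internal builder: block type as a Symbol, blocks per stage, plus keywords
    m10 = Metalhead.resnet(:bottleneck, [3, 4, 6, 3]; nclasses = 10)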
end -function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer}, - classifier_fn) - # Build stages of the ResNet - stage_blocks = resnet_stages(get_layers, block_repeats, connection) - backbone = Chain(stem, stage_blocks) - # Build the classifier head - nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] - classifier = classifier_fn(nfeaturemaps) - return Chain(backbone, classifier) -end - # block-layer configurations for ResNet-like models const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]), 34 => (:basicblock, [3, 4, 6, 3]), diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index ae25a73e8..824f2bbe9 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -27,7 +27,7 @@ end function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) _checkconfig(depth, keys(resnet_configs)) layers = resnet(resnet_configs[depth]...; inchannels, nclasses, - attn_fn = planes -> squeeze_excite(planes)) + attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNet", depth)) end @@ -70,7 +70,7 @@ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_widt inchannels = 3, nclasses = 1000) _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width, - attn_fn = planes -> squeeze_excite(planes)) + attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 2bc996278..3db3a2ccd 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -26,7 +26,7 @@ include("normalise.jl") export prenorm, ChannelLayerNorm include("conv.jl") -export conv_norm, depthwise_sep_conv_bn, invertedresidual +export conv_norm, depthwise_sep_conv_norm, invertedresidual include("drop.jl") export DropBlock, DropPath, droppath_rates diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 6f8a31d47..8a195158e 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,6 +1,6 @@ """ conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; - norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true, + norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init]) Create a convolution + batch normalization pair with activation. @@ -12,11 +12,11 @@ Create a convolution + batch normalization pair with activation. - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - `norm_layer`: the normalization layer used - - `prenorm`: set to `true` to place the batch norm before the convolution + - `revnorm`: set to `true` to place the batch norm before the convolution - `preact`: set to `true` to place the activation function before the batch norm - (only compatible with `prenorm = false`) + (only compatible with `revnorm = false`) - `use_bn`: set to `false` to disable batch normalization - (only compatible with `prenorm = false` and `preact = false`) + (only compatible with `revnorm = false` and `preact = false`) - `stride`: stride of the convolution kernel - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel @@ -24,16 +24,16 @@ Create a convolution + batch normalization pair with activation. 
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; - norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true, + norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true, kwargs...) if !use_bn - if (preact || prenorm) + if (preact || revnorm) throw(ArgumentError("`preact` only supported with `use_bn = true`")) else return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] end end - if prenorm + if revnorm activations = (conv = activation, bn = identity) bnplanes = inplanes else @@ -41,15 +41,15 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu bnplanes = outplanes end if preact - if prenorm - throw(ArgumentError("`preact` and `prenorm` cannot be set at the same time")) + if revnorm + throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time")) else activations = (conv = activation, bn = identity) end end layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), norm_layer(bnplanes, activations.bn)] - return prenorm ? reverse(layers) : layers + return revnorm ? reverse(layers) : layers end function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes, @@ -60,7 +60,7 @@ end """ depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu; - prenorm = false, use_bn = (true, true), + revnorm = false, use_bn = (true, true), stride = 1, pad = 0, dilation = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. @@ -79,20 +79,20 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `inplanes`: number of input feature maps - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - - `prenorm`: set to `true` to place the batch norm before the convolution - - `use_bn`: a tuple of two booleans to specify whether to use batch normalization for the first and second convolution + - `revnorm`: set to `true` to place the batch norm before the convolution + - `use_bn`: a tuple of two booleans to specify whether to use normalization for the first and second convolution - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ -function depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu; - prenorm = false, use_bn = (true, true), - stride = 1, kwargs...) +function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; + norm_layer = BatchNorm, revnorm = false, use_norm = (true, true), + stride = 1, kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; - prenorm, use_bn = use_bn[1], stride, groups = inplanes, + norm_layerm, revnorm, use_bn = use_bn[1], stride, groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; prenorm, + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, use_bn = use_bn[2])) end diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 0a74a24c0..1962ab0fb 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -10,6 +10,7 @@ produce a single output. Note that this is equivalent to - `output_size`: The size of the output after pooling. - `connection`: The connection type to use. 
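As a hypothetical sketch of the `conv_norm` ordering documented above (not from the patch; `conv_norm` returns a vector of layers, so it is splatted into a `Chain`, and the `Metalhead.Layers` submodule is assumed to be reachable by that name):

    using Flux, Metalhead
    using Metalhead.Layers: conv_norm

    # default ordering: Conv (no activation) followed by BatchNorm carrying relu
    m = Chain(conv_norm((3, 3), 16, 32, relu; pad = 1, bias = false)...)

    # revnorm = true: BatchNorm on the input planes first, then Conv with relu
    m_rev = Chain(conv_norm((3, 3), 16, 32, relu; revnorm = true, pad = 1, bias = false)...)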
""" -function AdaptiveMeanMaxPool(output_size = (1, 1); connection = +) +function AdaptiveMeanMaxPool(connection, output_size = (1, 1)) return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) end +AdaptiveMeanMaxPool(output_size::Tuple = (1, 1)) = AdaptiveMeanMaxPool(+, output_size) diff --git a/test/convnets.jl b/test/convnets.jl index 1b44676bc..258e037b6 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -23,13 +23,15 @@ end @testset "ResNet" begin # Tests for pretrained ResNets ## TODO: find a way to port pretrained models to the new ResNet API - # @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] + @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] + m = ResNet(sz) + @test size(m(x_224)) == (1000, 1) # if (ResNet, sz) in PRETRAINED_MODELS # @test acctest(ResNet(sz, pretrain = true)) # else # @test_throws ArgumentError ResNet(sz, pretrain = true) # end - # end + end @testset "resnet" begin @testset for block_fn in [:basicblock, :bottleneck] From 99eb25a4fee40b7da85c117d361830563fed3a15 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 22:34:02 +0530 Subject: [PATCH 61/64] Make all config dicts `const` and capitalise Also misc. formatting and cleanup --- src/convnets/convmixer.jl | 28 ++++++++++++------------- src/convnets/convnext.jl | 14 ++++++------- src/convnets/densenet.jl | 6 +++--- src/convnets/efficientnet.jl | 8 +++---- src/convnets/inception/xception.jl | 2 +- src/convnets/mobilenet/mobilenetv1.jl | 6 +++--- src/convnets/mobilenet/mobilenetv2.jl | 4 ++-- src/convnets/mobilenet/mobilenetv3.jl | 4 ++-- src/convnets/resnets/core.jl | 18 +++++++--------- src/convnets/resnets/resnet.jl | 4 ++-- src/convnets/resnets/resnext.jl | 4 ++-- src/convnets/resnets/seresnet.jl | 8 +++---- src/convnets/vgg.jl | 8 +++---- src/layers/conv.jl | 30 +++++++++++++-------------- src/mixers/core.jl | 8 +++---- src/mixers/gmlp.jl | 6 +++--- src/mixers/mlpmixer.jl | 6 +++--- src/mixers/resmlp.jl | 6 +++--- src/vit-based/vit.jl | 22 ++++++++++---------- test/convnets.jl | 2 +- 20 files changed, 95 insertions(+), 99 deletions(-) diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 6547ba4fb..aa3d144d2 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -28,15 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), return Chain(Chain(stem..., Chain(blocks)), head) end -convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7)), - :small => Dict(:planes => 768, :depth => 32, - :kernel_size => (7, 7), - :patch_size => (7, 7)), - :large => Dict(:planes => 1024, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7))) +const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7)), + :small => Dict(:planes => 768, :depth => 32, + :kernel_size => (7, 7), + :patch_size => (7, 7)), + :large => Dict(:planes => 1024, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7))) """ ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) @@ -57,11 +57,11 @@ end @functor ConvMixer function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) - _checkconfig(mode, keys(convmixer_configs)) - planes = convmixer_configs[mode][:planes] - depth = convmixer_configs[mode][:depth] - kernel_size = convmixer_configs[mode][:kernel_size] - patch_size = 
convmixer_configs[mode][:patch_size] + _checkconfig(mode, keys(CONVMIXER_CONFIGS)) + planes = CONVMIXER_CONFIGS[mode][:planes] + depth = CONVMIXER_CONFIGS[mode][:depth] + kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size] + patch_size = CONVMIXER_CONFIGS[mode][:patch_size] layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses) return ConvMixer(layers) diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 052192fec..e6ccee16a 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -66,11 +66,11 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 end # Configurations for ConvNeXt models -convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), - :small => ([3, 3, 27, 3], [96, 192, 384, 768]), - :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), - :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), - :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) +const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), + :small => ([3, 3, 27, 3], [96, 192, 384, 768]), + :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), + :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), + :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) struct ConvNeXt layers::Any @@ -94,8 +94,8 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - _checkconfig(mode, keys(convnext_configs)) - layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses) + _checkconfig(mode, keys(CONVNEXT_CONFIGS)) + layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses) return ConvNeXt(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index c41b4028b..332b5551f 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -140,7 +140,7 @@ end backbone(m::DenseNet) = m.layers[1] classifier(m::DenseNet) = m.layers[2] -const densenet_configs = Dict(121 => (6, 12, 24, 16), +const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16), 161 => (6, 12, 36, 24), 169 => (6, 12, 32, 32), 201 => (6, 12, 48, 32)) @@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. See also [`Metalhead.densenet`](#). """ function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - _checkconfig(config, keys(densenet_configs)) - model = DenseNet(densenet_configs[config]; nclasses = nclasses) + _checkconfig(config, keys(DENSENET_CONFIGS)) + model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses) if pretrain loadpretrain!(model, string("DenseNet", config)) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 122fd512a..4321e9443 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -59,7 +59,7 @@ end # e: expantion ratio # i: block input channels # o: block output channels -const efficientnet_block_configs = [ +const EFFICIENTNET_BLOCK_CONFIGS = [ # (n, k, s, e, i, o) (1, 3, 1, 1, 32, 16), (2, 3, 2, 6, 16, 24), @@ -73,7 +73,7 @@ const efficientnet_block_configs = [ # w: width scaling # d: depth scaling # r: image resolution -const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)), +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), :b1 => (240, (1.0, 1.1)), :b2 => (260, (1.1, 1.2)), :b3 => (300, (1.2, 1.4)), @@ -137,8 +137,8 @@ See also [`efficientnet`](#). 
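A hypothetical usage sketch tying the renamed EfficientNet config dicts to the exported constructor; the same pattern appears in the updated tests at the end of this series:

    using Metalhead

    name = :b0
    r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1]   # preferred resolution (224 for :b0)
    m = EfficientNet(name)
    size(m(rand(Float32, r, r, 3, 1)))                   # expected (1000, 1)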
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) - _checkconfig(name, keys(efficientnet_global_configs)) - model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) + _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS) pretrain && loadpretrain!(model, string("efficientnet-", name)) return model end diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index a585aadd4..3c6d8331a 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -35,7 +35,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, push!(layers, relu) append!(layers, depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, - use_bn = (false, false))) + use_norm = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index 22beaf86f..fe075d5ef 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -31,7 +31,7 @@ function mobilenetv1(width_mult, config; for _ in 1:nrepeats layer = dw ? depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : + stride = stride, pad = 1, bias = false) : conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, bias = false) append!(layers, layer) @@ -45,7 +45,7 @@ function mobilenetv1(width_mult, config; Dense(inchannels, nclasses))) end -const mobilenetv1_configs = [ +const MOBILENETV1_CONFIGS = [ # dw, c, s, r (false, 32, 2, 1), (true, 64, 1, 1), @@ -84,7 +84,7 @@ end function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) + layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv1")) end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 21c017b42..dd9eda012 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -46,7 +46,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla end # Layer configurations for MobileNetv2 -const mobilenetv2_configs = [ +const MOBILENETV2_CONFIGS = [ # t, c, n, s, a (1, 16, 1, 1, relu6), (6, 24, 2, 2, relu6), @@ -83,7 +83,7 @@ See also [`Metalhead.mobilenetv2`](#). 
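Similarly, a hypothetical sketch for the MobileNet constructors touched here, with the width multiplier as the first positional argument (illustrative only, not part of the patch):

    using Metalhead

    m1 = MobileNetv1(0.75)                    # width multiplier 0.75
    m2 = MobileNetv2()                        # defaults: width_mult = 1, nclasses = 1000
    size(m2(rand(Float32, 224, 224, 3, 1)))   # expected (1000, 1)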
""" function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) + layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) pretrain && loadpretrain!(layers, string("MobileNetv2")) if pretrain loadpretrain!(layers, string("MobileNetv2")) diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 6bc444407..5a06f6be5 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -53,7 +53,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla end # Configurations for small and large mode for MobileNetv3 -mobilenetv3_configs = Dict(:small => [ +MOBILENETV3_CONFIGS = Dict(:small => [ # k, t, c, SE, a, s (3, 1, 16, 4, relu, 2), (3, 4.5, 24, nothing, relu, 2), @@ -115,7 +115,7 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = pretrain = false, nclasses = 1000) @assert mode in [:large, :small] "`mode` has to be either :large or :small" max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, + layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv3", mode)) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index e42b81c70..aa7309a9b 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -25,7 +25,7 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor - outplanes = planes * expansion_factor(basicblock) + outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, stride, pad = 1, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, @@ -67,7 +67,7 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, attn_fn = planes -> identity) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduction_factor - outplanes = planes * expansion_factor(bottleneck) + outplanes = planes * 4 conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, bias = false) conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, @@ -215,7 +215,7 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer drop_block = DropBlock(blockschedule[schedule_idx]) block = basicblock(inplanes, planes; stride, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) # inplanes increases by expansion after each block inplanes = planes * expansion return block, downsample @@ -248,7 +248,7 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer block = bottleneck(inplanes, planes; stride, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) # inplanes increases by expansion after each block 
inplanes = planes * expansion return block, downsample @@ -256,10 +256,6 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer return get_layers end -# Makes the main stages of the ResNet model. This is an internal function and should not be -# used by end-users. `block_fn` is a function that returns a single block of the ResNet. -# See `basicblock` and `bottleneck` for examples. A block must define a function -# `expansion(::typeof(block))` that returns the expansion factor of the block. function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) # Construct each stage stages = [] @@ -316,15 +312,15 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, connection$activation, get_layers, - block_repeats, classifier_fn) + return resnet((imsize..., inchannels), stem, get_layers, block_repeats, + connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...) end # block-layer configurations for ResNet-like models -const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]), +const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]), 34 => (:basicblock, [3, 4, 6, 3]), 50 => (:bottleneck, [3, 4, 6, 3]), 101 => (:bottleneck, [3, 4, 23, 3]), diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 7bebb0873..46c0826c2 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -23,8 +23,8 @@ end @functor ResNet function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - _checkconfig(depth, keys(resnet_configs)) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses) + _checkconfig(depth, keys(RESNET_CONFIGS)) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses) if pretrain loadpretrain!(layers, string("ResNet", depth)) end diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 47e81d44d..8032df5ab 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -29,8 +29,8 @@ end function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width) + _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end]) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 824f2bbe9..05d842173 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -25,8 +25,8 @@ end (m::SEResNet)(x) = m.layers(x) function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - _checkconfig(depth, keys(resnet_configs)) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses, + _checkconfig(depth, keys(RESNET_CONFIGS)) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNet", depth)) @@ -68,8 +68,8 @@ end function SEResNeXt(depth::Integer; 
pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width, + _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end]) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 3a1a8ac10..ccfdd2cff 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -99,12 +99,12 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr return Chain(Chain(conv), class) end -const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], +const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) -const vgg_configs = Dict(11 => :A, +const VGG_CONFIGS = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E) @@ -153,8 +153,8 @@ See also [`VGG`](#). - `pretrain`: set to `true` to load pre-trained model weights for ImageNet """ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000) - _checkconfig(depth, keys(vgg_configs)) - model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]], + _checkconfig(depth, keys(VGG_CONFIGS)) + model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]], inchannels = 3, batchnorm = batchnorm, nclasses = nclasses, diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8a195158e..f5d94fbcb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,6 +1,6 @@ """ conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; - norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true, + norm_layer = BatchNorm, revnorm = false, preact = false, use_norm = true, stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init]) Create a convolution + batch normalization pair with activation. @@ -15,7 +15,7 @@ Create a convolution + batch normalization pair with activation. - `revnorm`: set to `true` to place the batch norm before the convolution - `preact`: set to `true` to place the activation function before the batch norm (only compatible with `revnorm = false`) - - `use_bn`: set to `false` to disable batch normalization + - `use_norm`: set to `false` to disable normalization (only compatible with `revnorm = false` and `preact = false`) - `stride`: stride of the convolution kernel - `pad`: padding of the convolution kernel @@ -24,11 +24,11 @@ Create a convolution + batch normalization pair with activation. - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; - norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true, + norm_layer = BatchNorm, revnorm = false, preact = false, use_norm = true, kwargs...) 
- if !use_bn + if !use_norm if (preact || revnorm) - throw(ArgumentError("`preact` only supported with `use_bn = true`")) + throw(ArgumentError("`preact` only supported with `use_norm = true`")) else return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] end @@ -59,17 +59,17 @@ function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes, end """ - depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu; - revnorm = false, use_bn = (true, true), + depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; + revnorm = false, use_norm = (true, true), stride = 1, pad = 0, dilation = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: - a `kernel_size` depthwise convolution from `inplanes => inplanes` - - a batch norm layer + `activation` (if `use_bn[1] == true`; otherwise `activation` is applied to the convolution output) + - a batch norm layer + `activation` (if `use_norm[1] == true`; otherwise `activation` is applied to the convolution output) - a `kernel_size` convolution from `inplanes => outplanes` - - a batch norm layer + `activation` (if `use_bn[2] == true`; otherwise `activation` is applied to the convolution output) + - a batch norm layer + `activation` (if `use_norm[2] == true`; otherwise `activation` is applied to the convolution output) See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). @@ -80,20 +80,20 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - `revnorm`: set to `true` to place the batch norm before the convolution - - `use_bn`: a tuple of two booleans to specify whether to use normalization for the first and second convolution + - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; - norm_layer = BatchNorm, revnorm = false, use_norm = (true, true), - stride = 1, kwargs...) + norm_layer = BatchNorm, revnorm = false, + use_norm = (true, true), stride = 1, kwargs...) 
return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; - norm_layerm, revnorm, use_bn = use_bn[1], stride, groups = inplanes, - kwargs...), + norm_layer, revnorm, use_norm = use_norm[1], stride, + groups = inplanes, kwargs...), conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, - use_bn = use_bn[2])) + use_norm = use_norm[2])) end """ diff --git a/src/mixers/core.jl b/src/mixers/core.jl index 6a55f048e..9f9d3b305 100644 --- a/src/mixers/core.jl +++ b/src/mixers/core.jl @@ -37,7 +37,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, end # Configurations for MLPMixer models -mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), - :base => Dict(:depth => 12, :planes => 768), - :large => Dict(:depth => 24, :planes => 1024), - :huge => Dict(:depth => 32, :planes => 1280)) +const MIXER_CONFIGS = Dict(:small => Dict(:depth => 8, :planes => 512), + :base => Dict(:depth => 12, :planes => 768), + :large => Dict(:depth => 24, :planes => 1024), + :huge => Dict(:depth => 32, :planes => 1280)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl index 4e681e9b4..9ebd2dce3 100644 --- a/src/mixers/gmlp.jl +++ b/src/mixers/gmlp.jl @@ -96,9 +96,9 @@ See also [`Metalhead.mlpmixer`](#). """ function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - _checkconfig(size, keys(mixer_configs)) - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, patch_size, embedplanes, drop_path_rate, depth, nclasses) return gMLP(layers) diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl index e3da17a23..7b6d4aa09 100644 --- a/src/mixers/mlpmixer.jl +++ b/src/mixers/mlpmixer.jl @@ -55,9 +55,9 @@ See also [`Metalhead.mlpmixer`](#). """ function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - _checkconfig(size, keys(mixer_configs)) - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, nclasses) return MLPMixer(layers) diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl index 38163702c..17e340310 100644 --- a/src/mixers/resmlp.jl +++ b/src/mixers/resmlp.jl @@ -58,9 +58,9 @@ See also [`Metalhead.mlpmixer`](#). 
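A hypothetical sketch of the mixer constructors that read from the renamed `MIXER_CONFIGS`; all three exported models take the configuration name as their first argument:

    using Metalhead

    m = MLPMixer(:small)                      # depth 8, 512 planes per MIXER_CONFIGS
    size(m(rand(Float32, 224, 224, 3, 1)))    # expected (1000, 1)
    g = gMLP(:base; drop_path_rate = 0.1)     # stochastic depth rate of 0.1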
""" function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - _checkconfig(size, keys(mixer_configs)) - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, drop_path_rate, depth, nclasses) return ResMLP(layers) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 93eba09ee..bcc5d43ba 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -62,15 +62,15 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end -vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), - :small => (depth = 12, embedplanes = 384, nheads = 6), - :base => (depth = 12, embedplanes = 768, nheads = 12), - :large => (depth = 24, embedplanes = 1024, nheads = 16), - :huge => (depth = 32, embedplanes = 1280, nheads = 16), - :giant => (depth = 40, embedplanes = 1408, nheads = 16, - mlp_ratio = 48 // 11), - :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, - mlp_ratio = 64 // 13)) +const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), + :small => (depth = 12, embedplanes = 384, nheads = 6), + :base => (depth = 12, embedplanes = 768, nheads = 12), + :large => (depth = 24, embedplanes = 1024, nheads = 16), + :huge => (depth = 32, embedplanes = 1280, nheads = 16), + :giant => (depth = 40, embedplanes = 1408, nheads = 16, + mlp_ratio = 48 // 11), + :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, + mlp_ratio = 64 // 13)) """ ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3, @@ -98,8 +98,8 @@ end function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) - _checkconfig(mode, keys(vit_configs)) - kwargs = vit_configs[mode] + _checkconfig(mode, keys(VIT_CONFIGS)) + kwargs = VIT_CONFIGS[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) 
return ViT(layers) end diff --git a/test/convnets.jl b/test/convnets.jl index 258e037b6..5740ed5c6 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -111,7 +111,7 @@ end @testset "EfficientNet" begin @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] # preferred image resolution scaling - r = Metalhead.efficientnet_global_configs[name][1] + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1] x = rand(Float32, r, r, 3, 1) m = EfficientNet(name) @test size(m(x)) == (1000, 1) From 73131bff2b892f9b6518ceb962012fc4b243a57a Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 30 Jul 2022 11:48:14 +0530 Subject: [PATCH 62/64] Formatting, and some tweaks --- src/convnets/mobilenet/mobilenetv1.jl | 1 + src/convnets/mobilenet/mobilenetv2.jl | 3 +- src/convnets/mobilenet/mobilenetv3.jl | 67 +++++++++++++-------------- src/convnets/resnets/core.jl | 12 +++-- src/layers/Layers.jl | 20 ++++---- src/layers/attention.jl | 39 ++++++---------- src/layers/conv.jl | 3 +- src/layers/mlp.jl | 4 ++ src/vit-based/vit.jl | 10 ++-- 9 files changed, 76 insertions(+), 83 deletions(-) diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index fe075d5ef..fffa93a4d 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -45,6 +45,7 @@ function mobilenetv1(width_mult, config; Dense(inchannels, nclasses))) end +# Layer configurations for MobileNetv1 const MOBILENETV1_CONFIGS = [ # dw, c, s, r (false, 32, 2, 1), diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index dd9eda012..a97e7dda1 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -47,7 +47,7 @@ end # Layer configurations for MobileNetv2 const MOBILENETV2_CONFIGS = [ - # t, c, n, s, a + # t, c, n, s, a (1, 16, 1, 1, relu6), (6, 24, 2, 2, relu6), (6, 32, 3, 2, relu6), @@ -57,7 +57,6 @@ const MOBILENETV2_CONFIGS = [ (6, 320, 1, 1, relu6), ] -# Model definition for MobileNetv2 struct MobileNetv2 layers::Any end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 5a06f6be5..d8666c5f3 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -52,41 +52,40 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) end -# Configurations for small and large mode for MobileNetv3 -MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) +# Layer configurations for small and large 
models for MobileNetv3 +const MOBILENETV3_CONFIGS = Dict(:small => [ + # k, t, c, SE, a, s + (3, 1, 16, 4, relu, 2), + (3, 4.5, 24, nothing, relu, 2), + (3, 3.67, 24, nothing, relu, 1), + (5, 4, 40, 4, hardswish, 2), + (5, 6, 40, 4, hardswish, 1), + (5, 6, 40, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 2), + (5, 6, 96, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 1), + ], + :large => [ + # k, t, c, SE, a, s + (3, 1, 16, nothing, relu, 1), + (3, 4, 24, nothing, relu, 2), + (3, 3, 24, nothing, relu, 1), + (5, 3, 40, 4, relu, 2), + (5, 3, 40, 4, relu, 1), + (5, 3, 40, 4, relu, 1), + (3, 6, 80, nothing, hardswish, 2), + (3, 2.5, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 2), + (5, 6, 160, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 1), + ]) -# Model definition for MobileNetv3 struct MobileNetv3 layers::Any end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index aa7309a9b..03d96d6db 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -215,7 +215,8 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer drop_block = DropBlock(blockschedule[schedule_idx]) block = basicblock(inplanes, planes; stride, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) # inplanes increases by expansion after each block inplanes = planes * expansion return block, downsample @@ -248,7 +249,8 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer block = bottleneck(inplanes, planes; stride, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) # inplanes increases by expansion after each block inplanes = planes * expansion return block, downsample @@ -298,13 +300,15 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, + planes_fn = resnet_planes, downsample_tuple = downsample_opt) elseif block_type == :bottleneck get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, + planes_fn = resnet_planes, downsample_tuple = downsample_opt) else # TODO: write better message when we have link to dev docs for resnet diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 3db3a2ccd..04be476ff 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -16,6 +16,12 @@ include("../utilities.jl") include("attention.jl") export MHAttention +include("conv.jl") +export conv_norm, depthwise_sep_conv_norm, invertedresidual + +include("drop.jl") +export DropBlock, DropPath + include("embeddings.jl") export PatchEmbedding, 
ViPosEmbedding, ClassTokens @@ -25,19 +31,13 @@ export mlp_block, gated_mlp_block, create_fc, create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm -include("conv.jl") -export conv_norm, depthwise_sep_conv_norm, invertedresidual - -include("drop.jl") -export DropBlock, DropPath, droppath_rates - -include("selayers.jl") -export squeeze_excite, effective_squeeze_excite +include("pool.jl") +export AdaptiveMeanMaxPool include("scale.jl") export LayerScale, inputscale -include("pool.jl") -export AdaptiveMeanMaxPool +include("selayers.jl") +export squeeze_excite, effective_squeeze_excite end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 7d8ee776d..e2276aa01 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,46 +1,33 @@ """ - MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection) + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_rate = 0., proj_dropout_rate = 0.) Multi-head self-attention layer. # Arguments - - `nheads`: Number of heads - - `qkv_layer`: layer to be used for getting the query, key and value - - `attn_drop_rate`: dropout rate after the self-attention layer - - `projection`: projection layer to be used after self-attention + - `planes`: number of input channels + - `nheads`: number of heads + - `qkv_bias`: whether to use bias in the layer to get the query, key and value + - `attn_dropout_rate`: dropout rate after the self-attention layer + - `proj_dropout_rate`: dropout rate after the projection layer """ struct MHAttention{P, Q, R} nheads::Int qkv_layer::P - attn_drop_rate::Q + attn_drop::Q projection::R end +@functor MHAttention -""" - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.) - -Multi-head self-attention layer. 
- -# Arguments - - - `planes`: number of input channels - - `nheads`: number of heads - - `qkv_bias`: whether to use bias in the layer to get the query, key and value - - `attn_drop_rate`: dropout rate after the self-attention layer - - `proj_drop_rate`: dropout rate after the projection layer -""" function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_drop_rate = 0.0, proj_drop_rate = 0.0) + attn_dropout_rate = 0.0, proj_dropout_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop_rate = Dropout(attn_drop_rate) - proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate)) - return MHAttention(nheads, qkv_layer, attn_drop_rate, proj) + attn_drop = Dropout(attn_dropout_rate) + proj = Chain(Dense(planes, planes), Dropout(proj_dropout_rate)) + return MHAttention(nheads, qkv_layer, attn_drop, proj) end -@functor MHAttention - function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} nfeatures, seq_len, batch_size = size(x) x_reshaped = reshape(x, nfeatures, seq_len * batch_size) @@ -52,7 +39,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} seq_len * batch_size) query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) - attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) + attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) pre_projection = reshape(batched_mul(attention, value_reshaped), diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f5d94fbcb..02d80d67a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -52,8 +52,7 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes, - activation = identity; kwargs...) +function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...) inplanes, outplanes = ch return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 4a623c977..a3bdb0fb5 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -69,6 +69,10 @@ function create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1, "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" end flatten_in_pool = !use_conv && pool_layer !== identity + if use_conv + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv` is true" + end global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer # Fully-connected layer fc = use_conv ? Conv((1, 1), inplanes => nclasses) : Dense(inplanes => nclasses) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index bcc5d43ba..3e74caed5 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -15,8 +15,8 @@ Transformer as used in the base ViT architecture. 
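Before it is wired into the transformer encoder below, the reworked `MHAttention` can be exercised on its own. A hedged sketch (it assumes the internal `Layers` module is reachable as `Metalhead.Layers`, and that, as in the code above, `planes` is divisible by `nheads`; the layer maps a `(features, sequence, batch)` array to an array of the same shape):

```julia
using Metalhead, Flux

attn = Metalhead.Layers.MHAttention(64, 8; qkv_bias = true,
                                    attn_dropout_rate = 0.1, proj_dropout_rate = 0.1)

x = rand(Float32, 64, 197, 2)   # (features, sequence length, batch)
size(attn(x))                   # expected: (64, 197, 2)
```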
function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.0) layers = [Chain(SkipConnection(prenorm(planes, MHAttention(planes, nheads; - attn_drop_rate = dropout_rate, - proj_drop_rate = dropout_rate)), +), + attn_dropout_rate = dropout_rate, + proj_dropout_rate = dropout_rate)), +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); dropout_rate)), +)) @@ -27,7 +27,7 @@ end """ vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, - emb_drop_rate = 0.1, pool = :class, nclasses = 1000) + emb_dropout_rate = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -48,14 +48,14 @@ Creates a Vision Transformer (ViT) model. """ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, - emb_drop_rate = 0.1, pool = :class, nclasses = 1000) + emb_dropout_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] "Pool type must be either `:class` (class token) or `:mean` (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), ViPosEmbedding(embedplanes, npatches + 1), - Dropout(emb_drop_rate), + Dropout(emb_dropout_rate), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout_rate), (pool == :class) ? x -> x[:, 1, :] : seconddimmean), From 73df024a138eae77f418c0863e3c9eb3d2f7eaad Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 30 Jul 2022 18:55:33 +0530 Subject: [PATCH 63/64] Add WideResNet --- .github/workflows/CI.yml | 3 ++- src/Metalhead.jl | 2 +- src/convnets/resnets/resnet.jl | 40 ++++++++++++++++++++++++++++++++++ src/layers/conv.jl | 3 ++- src/vit-based/vit.jl | 3 ++- test/convnets.jl | 9 ++++++++ 6 files changed, 56 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8fa607912..d43e61da4 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,7 +29,8 @@ jobs: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - '["EfficientNet"]' - - '[r"ResNet", r"ResNeXt"]' + - 'r"/^ResNet\z/"' + - '[r"ResNeXt", r"SEResNet"]' - '"Inception"' - '"DenseNet"' - '["ConvNeXt", "ConvMixer"]' diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 3c8469dd2..374f28615 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -60,7 +60,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, - SEResNet, SEResNeXt, + WideResNet, SEResNet, SEResNeXt, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 46c0826c2..fac7e7415 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -35,3 +35,43 @@ end backbone(m::ResNet) = m.layers[1] classifier(m::ResNet) = m.layers[2] + +""" + WideResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a Wide ResNet model with the specified depth. 
The model is the same as ResNet +except for the bottleneck number of channels which is twice larger in every block. +The number of channels in outer 1x1 convolutions is the same. +((reference)[https://arxiv.org/abs/1605.07146]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the Wide ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `WideResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct WideResNet + layers::Any +end +@functor WideResNet + +function WideResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + _checkconfig(depth, [50, 101]) + layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("WideResNet", depth)) + end + return WideResNet(layers) +end + +(m::WideResNet)(x) = m.layers(x) + +backbone(m::WideResNet) = m.layers[1] +classifier(m::WideResNet) = m.layers[2] diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 02d80d67a..5610d3be2 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -52,7 +52,8 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...) +function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; + kwargs...) inplanes, outplanes = ch return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 3e74caed5..1fece2191 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -16,7 +16,8 @@ function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rat layers = [Chain(SkipConnection(prenorm(planes, MHAttention(planes, nheads; attn_dropout_rate = dropout_rate, - proj_dropout_rate = dropout_rate)), +), + proj_dropout_rate = dropout_rate)), + +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); dropout_rate)), +)) diff --git a/test/convnets.jl b/test/convnets.jl index 5740ed5c6..40f5ec75a 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -56,6 +56,15 @@ end end end end + + @testset "WideResNet" begin + @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end + end end @testset "ResNeXt" begin From 72cd4a9bcc03840006a7b3076baf6a8b10cbee74 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Tue, 2 Aug 2022 08:19:24 +0530 Subject: [PATCH 64/64] Don't use globals --- .github/workflows/CI.yml | 2 +- src/convnets/resnets/core.jl | 27 +++++++++++++++++---------- test/convnets.jl | 5 +++++ 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d43e61da4..8de5bd6e0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,7 +29,7 @@ jobs: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - '["EfficientNet"]' - - 'r"/^ResNet\z/"' + - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' - '"Inception"' - '"DenseNet"' diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 03d96d6db..329663c13 
100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -51,7 +51,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512. - `stride`: the stride of the block - `cardinality`: the number of groups in the convolution. - `base_width`: the number of output feature maps for each convolutional group. - - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first + - `reduction_factor`: the factor by which the input feature maps are reduced before the first convolution. - `activation`: the activation function to use. - `norm_layer`: the normalization layer to use. @@ -190,7 +190,10 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, return Chain(conv1, bn1, stempool) end -resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1) +function resnet_planes(block_repeats::Vector{<:Integer}) + return Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages) + for (stage_idx, stages) in enumerate(block_repeats)) +end function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, reduction_factor::Integer = 1, expansion::Integer = 1, @@ -201,24 +204,25 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) - planes = planes_fn(stage_idx) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) downsample_fn = (stride != 1 || inplanes != planes * expansion) ? 
downsample_tuple[1] : downsample_tuple[2] - # DropBlock, DropPath both take in rates based on a linear scaling schedule - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx drop_path = DropPath(pathschedule[schedule_idx]) drop_block = DropBlock(blockschedule[schedule_idx]) block = basicblock(inplanes, planes; stride, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) - # inplanes increases by expansion after each block - inplanes = planes * expansion return block, downsample end return get_layers @@ -234,9 +238,14 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) - planes = planes_fn(stage_idx) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion # `resnet_stride` is a callback that the user can tweak to change the stride of the # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) @@ -251,8 +260,6 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer attn_fn, drop_path, drop_block) downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) - # inplanes increases by expansion after each block - inplanes = planes * expansion return block, downsample end return get_layers diff --git a/test/convnets.jl b/test/convnets.jl index 40f5ec75a..e62b14299 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -63,6 +63,11 @@ end @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(ResNet(sz, pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) + end end end end
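To make the closure refactor in this last patch concrete, here is a small self-contained sketch of the indexing it relies on: `resnet_planes` (reproduced from the patch) yields the width of every block across all stages, and `schedule_idx` is the shared linear index used for the planes lookup and the DropPath/DropBlock schedules. Values below assume the standard ResNet-50 stage layout and a bottleneck expansion of 4.

```julia
block_repeats = [3, 4, 6, 3]   # ResNet-50 style stages

resnet_planes(block_repeats) =
    Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages)
                      for (stage_idx, stages) in enumerate(block_repeats))

planes_vec = collect(resnet_planes(block_repeats))
# 16-element Vector{Int64}: 64, 64, 64, 128, 128, 128, 128, 256, ..., 512

# Linear index for stage 3, block 2:
stage_idx, block_idx = 3, 2
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx   # 7 + 2 = 9
planes_vec[schedule_idx]                                           # 256

# With `expansion = 4` (bottleneck blocks), the block's input width comes from
# the previous block's planes, so no mutable state is needed inside the closure:
inplanes = schedule_idx == 1 ? 64 : planes_vec[schedule_idx - 1] * 4   # 1024
```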