diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c383351bc..8de5bd6e0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,16 +27,15 @@ jobs: - x64 suite: - '["AlexNet", "VGG"]' - - '["GoogLeNet", "SqueezeNet"]' - - '["EfficientNet", "MobileNet"]' - - '[r"/*/ResNet*", "ResNeXt"]' - - 'r"/*/Inception/Inceptionv*"' - - '["InceptionResNetv2", "Xception"]' + - '["GoogLeNet", "SqueezeNet", "MobileNet"]' + - '["EfficientNet"]' + - 'r"/*/ResNet*"' + - '[r"ResNeXt", r"SEResNet"]' + - '"Inception"' - '"DenseNet"' - - '"ConvNeXt"' - - '"ConvMixer"' - - '"ViT"' - - '"Other"' + - '["ConvNeXt", "ConvMixer"]' + - 'r"ViTs"' + - 'r"Mixers"' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index afdffb3fd..6230fdbee 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,15 @@ version = "0.8.0-DEV" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/docs/make.jl b/docs/make.jl index db03f1d76..f5d29f7e9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Pkg -Pkg.develop(path = "..") +Pkg.develop(; path = "..") using Publish using Artifacts, LazyArtifacts @@ -13,5 +13,5 @@ p = Publish.Project(Metalhead) function build_and_deploy(label) rm(label; recursive = true, force = true) - deploy(Metalhead; root = "/Metalhead.jl", label = label) + return deploy(Metalhead; root = "/Metalhead.jl", label = label) end diff --git a/docs/serve.jl b/docs/serve.jl index 763e77e93..bf4a51179 100644 --- a/docs/serve.jl +++ b/docs/serve.jl @@ -1,6 +1,6 @@ using Pkg -Pkg.develop(path = "..") +Pkg.develop(; path = "..") using Revise using Publish diff --git a/src/Metalhead.jl b/src/Metalhead.jl index f391c0c66..374f28615 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -7,6 +7,7 @@ using BSON using Artifacts, LazyArtifacts using Statistics using MLUtils +using PartialFunctions using Random import Functors @@ -20,23 +21,38 @@ using .Layers # CNN models include("convnets/alexnet.jl") include("convnets/vgg.jl") -include("convnets/inception.jl") -include("convnets/googlenet.jl") -include("convnets/resnet.jl") -include("convnets/resnext.jl") +## ResNets +include("convnets/resnets/core.jl") +include("convnets/resnets/resnet.jl") +include("convnets/resnets/resnext.jl") +include("convnets/resnets/seresnet.jl") +## Inceptions +include("convnets/inception/googlenet.jl") +include("convnets/inception/inceptionv3.jl") +include("convnets/inception/inceptionv4.jl") +include("convnets/inception/inceptionresnetv2.jl") +include("convnets/inception/xception.jl") +## MobileNets +include("convnets/mobilenet/mobilenetv1.jl") +include("convnets/mobilenet/mobilenetv2.jl") +include("convnets/mobilenet/mobilenetv3.jl") +## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") -include("convnets/mobilenet.jl") include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") -# Other models -include("other/mlpmixer.jl") +# Mixers +include("mixers/core.jl") +include("mixers/mlpmixer.jl") +include("mixers/resmlp.jl") +include("mixers/gmlp.jl") -# ViT-based models +# ViTs include("vit-based/vit.jl") +# Load pretrained weights include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, @@ -44,14 +60,15 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, + WideResNet, SEResNet, SEResNeXt, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, +for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, - :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, + :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 405272dd2..8ff65ffef 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -24,7 +24,6 @@ function alexnet(; nclasses = 1000) Dropout(0.5), Dense(4096, 4096, relu), Dense(4096, nclasses))) - return layers end @@ -46,15 +45,16 @@ See also [`alexnet`](#). struct AlexNet layers::Any end +@functor AlexNet function AlexNet(; pretrain = false, nclasses = 1000) layers = alexnet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "AlexNet") + if pretrain + loadpretrain!(layers, "AlexNet") + end return AlexNet(layers) end -@functor AlexNet - (m::AlexNet)(x) = m.layers(x) backbone(m::AlexNet) = m.layers[1] diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index f70473ca5..aa3d144d2 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -9,7 +9,7 @@ Creates a ConvMixer model. - `planes`: number of planes in the output of each block - `depth`: number of layers - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `kernel_size`: kernel size of the convolutional layers - `patch_size`: size of the patches - `activation`: activation function used after the convolutional layers @@ -17,24 +17,26 @@ Creates a ConvMixer model. """ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000) - stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, - stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation; - preact = true, groups = planes, - pad = SamePad())), +), - conv_bn((1, 1), planes, planes, activation; preact = true)...) + stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, + stride = patch_size[1]) + blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; + preact = true, groups = planes, + pad = SamePad())), +), + conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses)) return Chain(Chain(stem..., Chain(blocks)), head) end -convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9), - :patch_size => (7, 7)), - :small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7), - :patch_size => (7, 7)), - :large => Dict(:planes => 1024, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7))) +const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7)), + :small => Dict(:planes => 768, :depth => 32, + :kernel_size => (7, 7), + :patch_size => (7, 7)), + :large => Dict(:planes => 1024, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7))) """ ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) @@ -45,26 +47,26 @@ Creates a ConvMixer model. # Arguments - `mode`: the mode of the model, either `:base`, `:small` or `:large` - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `activation`: activation function used after the convolutional layers - `nclasses`: number of classes in the output """ struct ConvMixer layers::Any end +@functor ConvMixer function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) - planes = convmixer_config[mode][:planes] - depth = convmixer_config[mode][:depth] - kernel_size = convmixer_config[mode][:kernel_size] - patch_size = convmixer_config[mode][:patch_size] + _checkconfig(mode, keys(CONVMIXER_CONFIGS)) + planes = CONVMIXER_CONFIGS[mode][:planes] + depth = CONVMIXER_CONFIGS[mode][:depth] + kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size] + patch_size = CONVMIXER_CONFIGS[mode][:patch_size] layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses) return ConvMixer(layers) end -@functor ConvMixer - (m::ConvMixer)(x) = m.layers(x) backbone(m::ConvMixer) = m.layers[1] diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 3fef58d1d..e6ccee16a 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -4,7 +4,7 @@ Creates a single block of ConvNeXt. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `planes`: number of input channels. - `drop_path_rate`: Stochastic depth rate. @@ -27,7 +27,7 @@ end Creates the layers for a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - `inchannels`: number of input channels. - `depths`: list with configuration for depth of each block @@ -39,50 +39,43 @@ Creates the layers for a ConvNeXt model. """ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert length(depths)==length(planes) "`planes` should have exactly one value for each block" - + @assert length(depths) == length(planes) + "`planes` should have exactly one value for each block" downsample_layers = [] stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4), - ChannelLayerNorm(planes[1]; ϵ = 1.0f-6)) + ChannelLayerNorm(planes[1])) push!(downsample_layers, stem) for m in 1:(length(depths) - 1) - downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6), + downsample_layer = Chain(ChannelLayerNorm(planes[m]), Conv((2, 2), planes[m] => planes[m + 1]; stride = 2)) push!(downsample_layers, downsample_layer) end - stages = [] - dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths)) + dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths)) cur = 0 - for i in 1:length(depths) + for i in eachindex(depths) push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]]) cur += depths[i] end - backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages)))) head = Chain(GlobalMeanPool(), MLUtils.flatten, LayerNorm(planes[end]), Dense(planes[end], nclasses)) - return Chain(Chain(backbone), head) end # Configurations for ConvNeXt models -convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3], - :planes => [96, 192, 384, 768]), - :small => Dict(:depths => [3, 3, 27, 3], - :planes => [96, 192, 384, 768]), - :base => Dict(:depths => [3, 3, 27, 3], - :planes => [128, 256, 512, 1024]), - :large => Dict(:depths => [3, 3, 27, 3], - :planes => [192, 384, 768, 1536]), - :xlarge => Dict(:depths => [3, 3, 27, 3], - :planes => [256, 512, 1024, 2048])) +const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), + :small => ([3, 3, 27, 3], [96, 192, 384, 768]), + :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), + :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), + :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) struct ConvNeXt layers::Any end +@functor ConvNeXt """ ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000) @@ -90,9 +83,9 @@ end Creates a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) -# Arguments: +# Arguments - - `inchannels`: The number of channels in the input. The default value is 3. + - `inchannels`: The number of channels in the input. - `drop_path_rate`: Stochastic depth rate. - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239) - `nclasses`: number of output classes @@ -101,16 +94,12 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))" - depths = convnext_configs[mode][:depths] - planes = convnext_configs[mode][:planes] - layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses) + _checkconfig(mode, keys(CONVNEXT_CONFIGS)) + layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses) return ConvNeXt(layers) end (m::ConvNeXt)(x) = m.layers(x) -@functor ConvNeXt - backbone(m::ConvNeXt) = m.layers[1] classifier(m::ConvNeXt) = m.layers[2] diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 0b318dbf3..332b5551f 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -12,10 +12,10 @@ Create a Densenet bottleneck layer """ function dense_bottleneck(inplanes, outplanes) inner_channels = 4 * outplanes - return SkipConnection(Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, - rev = true)..., - conv_bn((3, 3), inner_channels, outplanes; pad = 1, - bias = false, rev = true)...), + return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false, + revnorm = true)..., + conv_norm((3, 3), inner_channels, outplanes; pad = 1, + bias = false, revnorm = true)...), cat_channels) end @@ -31,7 +31,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ function transition(inplanes, outplanes) - return Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., + return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)..., MeanPool((2, 2))) end @@ -70,7 +70,7 @@ Create a DenseNet model """ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000) layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)) + append!(layers, conv_norm((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)) push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) @@ -100,7 +100,8 @@ Create a DenseNet model - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ -function densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000) +function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, + nclasses = 1000) where {N} return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; reduction = reduction, nclasses = nclasses) end @@ -139,14 +140,14 @@ end backbone(m::DenseNet) = m.layers[1] classifier(m::DenseNet) = m.layers[2] -const densenet_config = Dict(121 => (6, 12, 24, 16), - 161 => (6, 12, 36, 24), - 169 => (6, 12, 32, 32), - 201 => (6, 12, 48, 32)) +const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16), + 161 => (6, 12, 36, 24), + 169 => (6, 12, 32, 32), + 201 => (6, 12, 48, 32)) """ DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - DenseNet(transition_config::NTuple{N,Integer}) + DenseNet(transition_configs::NTuple{N,Integer}) Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201) ([reference](https://arxiv.org/abs/1608.06993)). @@ -159,8 +160,10 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. See also [`Metalhead.densenet`](#). """ function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))." - model = DenseNet(densenet_config[config]; nclasses = nclasses) - pretrain && loadpretrain!(model, string("DenseNet", config)) + _checkconfig(config, keys(DENSENET_CONFIGS)) + model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses) + if pretrain + loadpretrain!(model, string("DenseNet", config)) + end return model end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 1465eb238..4321e9443 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -1,41 +1,40 @@ """ - efficientnet(scalings, block_config; + efficientnet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). # Arguments -- `scalings`: global width and depth scaling (given as a tuple) -- `block_config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - `n`: number of block repetitions (will be scaled by global depth scaling) - - `k`: kernel size - - `s`: kernel stride - - `e`: expansion ratio - - `i`: block input channels (will be scaled by global width scaling) - - `o`: block output channels (will be scaled by global width scaling) -- `inchannels`: number of input channels -- `nclasses`: number of output classes -- `max_width`: maximum number of output channels before the fully connected - classification blocks + - `scalings`: global width and depth scaling (given as a tuple) + + - `block_configs`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `n`: number of block repetitions (will be scaled by global depth scaling) + + `k`: kernel size + + `s`: kernel stride + + `e`: expansion ratio + + `i`: block input channels (will be scaled by global width scaling) + + `o`: block output channels (will be scaled by global width scaling) + - `inchannels`: number of input channels + - `nclasses`: number of output classes + - `max_width`: maximum number of output channels before the fully connected + classification blocks """ -function efficientnet(scalings, block_config; +function efficientnet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) - out_channels = _round_channels(scalew(32), 8) - stem = conv_bn((3, 3), inchannels, out_channels, swish; - bias = false, stride = 2, pad = SamePad()) - + stem = conv_norm((3, 3), inchannels, out_channels, swish; + bias = false, stride = 2, pad = SamePad()) blocks = [] - for (n, k, s, e, i, o) in block_config + for (n, k, s, e, i, o) in block_configs in_channels = _round_channels(scalew(i), 8) out_channels = _round_channels(scalew(o), 8) repeats = scaled(n) - push!(blocks, invertedresidual(k, in_channels, in_channels * e, out_channels, swish; stride = s, reduction = 4)) @@ -46,13 +45,10 @@ function efficientnet(scalings, block_config; end end blocks = Chain(blocks...) - head_out_channels = _round_channels(max_width, 8) - head = conv_bn((1, 1), out_channels, head_out_channels, swish; - bias = false, pad = SamePad()) - + head = conv_norm((1, 1), out_channels, head_out_channels, swish; + bias = false, pad = SamePad()) top = Dense(head_out_channels, nclasses) - return Chain(Chain([stem..., blocks, head...]), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top)) end @@ -63,39 +59,37 @@ end # e: expantion ratio # i: block input channels # o: block output channels -const efficientnet_block_configs = [ -# (n, k, s, e, i, o) - (1, 3, 1, 1, 32, 16), - (2, 3, 2, 6, 16, 24), - (2, 5, 2, 6, 24, 40), - (3, 3, 2, 6, 40, 80), - (3, 5, 1, 6, 80, 112), +const EFFICIENTNET_BLOCK_CONFIGS = [ + # (n, k, s, e, i, o) + (1, 3, 1, 1, 32, 16), + (2, 3, 2, 6, 16, 24), + (2, 5, 2, 6, 24, 40), + (3, 3, 2, 6, 40, 80), + (3, 5, 1, 6, 80, 112), (4, 5, 2, 6, 112, 192), - (1, 3, 1, 6, 192, 320) + (1, 3, 1, 6, 192, 320), ] # w: width scaling # d: depth scaling # r: image resolution -const efficientnet_global_configs = Dict( -# ( r, ( w, d)) - :b0 => (224, (1.0, 1.0)), - :b1 => (240, (1.0, 1.1)), - :b2 => (260, (1.1, 1.2)), - :b3 => (300, (1.2, 1.4)), - :b4 => (380, (1.4, 1.8)), - :b5 => (456, (1.6, 2.2)), - :b6 => (528, (1.8, 2.6)), - :b7 => (600, (2.0, 3.1)), - :b8 => (672, (2.2, 3.6)) -) +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), + :b1 => (240, (1.0, 1.1)), + :b2 => (260, (1.1, 1.2)), + :b3 => (300, (1.2, 1.4)), + :b4 => (380, (1.4, 1.8)), + :b5 => (456, (1.6, 2.2)), + :b6 => (528, (1.8, 2.6)), + :b7 => (600, (2.0, 3.1)), + :b8 => (672, (2.2, 3.6))) struct EfficientNet - layers::Any + layers::Any end +@functor EfficientNet """ - EfficientNet(scalings, block_config; + EfficientNet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). @@ -103,31 +97,28 @@ See also [`efficientnet`](#). # Arguments -- `scalings`: global width and depth scaling (given as a tuple) -- `block_config`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - `n`: number of block repetitions (will be scaled by global depth scaling) - - `k`: kernel size - - `s`: kernel stride - - `e`: expansion ratio - - `i`: block input channels (will be scaled by global width scaling) - - `o`: block output channels (will be scaled by global width scaling) -- `inchannels`: number of input channels -- `nclasses`: number of output classes -- `max_width`: maximum number of output channels before the fully connected - classification blocks + - `scalings`: global width and depth scaling (given as a tuple) + + - `block_configs`: configuration for each inverted residual block, + given as a vector of tuples with elements: + + + `n`: number of block repetitions (will be scaled by global depth scaling) + + `k`: kernel size + + `s`: kernel stride + + `e`: expansion ratio + + `i`: block input channels (will be scaled by global width scaling) + + `o`: block output channels (will be scaled by global width scaling) + - `inchannels`: number of input channels + - `nclasses`: number of output classes + - `max_width`: maximum number of output channels before the fully connected + classification blocks """ -function EfficientNet(scalings, block_config; +function EfficientNet(scalings, block_configs; inchannels = 3, nclasses = 1000, max_width = 1280) - layers = efficientnet(scalings, block_config; - inchannels = inchannels, - nclasses = nclasses, - max_width = max_width) - return EfficientNet(layers) + layers = efficientnet(scalings, block_configs; inchannels, nclasses, max_width) + return EfficientNet(layers) end -@functor EfficientNet - (m::EfficientNet)(x) = m.layers(x) backbone(m::EfficientNet) = m.layers[1] @@ -141,16 +132,13 @@ See also [`efficientnet`](#). # Arguments -- `name`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) -- `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `name`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) - @assert name in keys(efficientnet_global_configs) - "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" - - model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) + _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS) pretrain && loadpretrain!(model, string("efficientnet-", name)) - return model end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl deleted file mode 100644 index c3fd39f5e..000000000 --- a/src/convnets/inception.jl +++ /dev/null @@ -1,593 +0,0 @@ -## Inceptionv3 - -""" - inceptionv3_a(inplanes, pool_proj) - -Create an Inception-v3 style-A module -(ref: Fig. 5 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps - - `pool_proj`: the number of output feature maps for the pooling projection -""" -function inceptionv3_a(inplanes, pool_proj) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 64)) - branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)..., - conv_bn((5, 5), 48, 64; pad = 2)...) - branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), inplanes, pool_proj)...) - return Parallel(cat_channels, - branch1x1, branch5x5, branch3x3, branch_pool) -end - -""" - inceptionv3_b(inplanes) - -Create an Inception-v3 style-B module -(ref: Fig. 10 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_b(inplanes) - branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2)) - branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; stride = 2)...) - branch_pool = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, - branch3x3_1, branch3x3_2, branch_pool) -end - -""" - inceptionv3_c(inplanes, inner_planes, n = 7) - -Create an Inception-v3 style-C module -(ref: Fig. 6 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps - - `inner_planes`: the number of output feature maps within each branch - - `n`: the "grid size" (kernel size) for the convolution layers -""" -function inceptionv3_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 192)) - branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)..., - conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...) - branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)..., - conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_bn((1, n), inner_planes, 192; pad = (0, 3))...) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), inplanes, 192)...) - return Parallel(cat_channels, - branch1x1, branch7x7_1, branch7x7_2, branch_pool) -end - -""" - inceptionv3_d(inplanes) - -Create an Inception-v3 style-D module -(ref: [pytorch](https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/inception.py#L322)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_d(inplanes) - branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)..., - conv_bn((3, 3), 192, 320; stride = 2)...) - branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)..., - conv_bn((1, 7), 192, 192; pad = (0, 3))..., - conv_bn((7, 1), 192, 192; pad = (3, 0))..., - conv_bn((3, 3), 192, 192; stride = 2)...) - branch_pool = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, - branch3x3, branch7x7x3, branch_pool) -end - -""" - inceptionv3_e(inplanes) - -Create an Inception-v3 style-E module -(ref: Fig. 7 in [paper](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `inplanes`: number of input feature maps -""" -function inceptionv3_e(inplanes) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 320)) - branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384)) - branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))) - branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))) - branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)..., - conv_bn((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))) - branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))) - branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), inplanes, 192)...) - return Parallel(cat_channels, - branch1x1, - Chain(branch3x3_1, - Parallel(cat_channels, - branch3x3_1a, branch3x3_1b)), - Chain(branch3x3_2, - Parallel(cat_channels, - branch3x3_2a, branch3x3_2b)), - branch_pool) -end - -""" - inceptionv3(; nclasses = 1000) - -Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - -# Arguments - - - `nclasses`: the number of output classes -""" -function inceptionv3(; nclasses = 1000) - layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., - MaxPool((3, 3); stride = 2), - conv_bn((1, 1), 64, 80)..., - conv_bn((3, 3), 80, 192)..., - MaxPool((3, 3); stride = 2), - inceptionv3_a(192, 32), - inceptionv3_a(256, 64), - inceptionv3_a(288, 64), - inceptionv3_b(288), - inceptionv3_c(768, 128), - inceptionv3_c(768, 160), - inceptionv3_c(768, 160), - inceptionv3_c(768, 192), - inceptionv3_d(768), - inceptionv3_e(1280), - inceptionv3_e(2048)), - Chain(AdaptiveMeanPool((1, 1)), - Dropout(0.2), - MLUtils.flatten, - Dense(2048, nclasses))) - return layer -end - -""" - Inceptionv3(; pretrain = false, nclasses = 1000) - -Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). -See also [`inceptionv3`](#). - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -!!! warning - - `Inceptionv3` does not currently support pretrained weights. -""" -struct Inceptionv3 - layers::Any -end - -function Inceptionv3(; pretrain = false, nclasses = 1000) - layers = inceptionv3(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "Inceptionv3") - return Inceptionv3(layers) -end - -@functor Inceptionv3 - -(m::Inceptionv3)(x) = m.layers(x) - -backbone(m::Inceptionv3) = m.layers[1] -classifier(m::Inceptionv3) = m.layers[2] - -## Inceptionv4 - -function mixed_3a() - return Parallel(cat_channels, - MaxPool((3, 3); stride = 2), - Chain(conv_bn((3, 3), 64, 96; stride = 2)...)) -end - -function mixed_4a() - return Parallel(cat_channels, - Chain(conv_bn((1, 1), 160, 64)..., - conv_bn((3, 3), 64, 96)...), - Chain(conv_bn((1, 1), 160, 64)..., - conv_bn((1, 7), 64, 64; pad = (0, 3))..., - conv_bn((7, 1), 64, 64; pad = (3, 0))..., - conv_bn((3, 3), 64, 96)...)) -end - -function mixed_5a() - return Parallel(cat_channels, - Chain(conv_bn((3, 3), 192, 192; stride = 2)...), - MaxPool((3, 3); stride = 2)) -end - -function inceptionv4_a() - branch1 = Chain(conv_bn((1, 1), 384, 96)...) - branch2 = Chain(conv_bn((1, 1), 384, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)...) - branch3 = Chain(conv_bn((1, 1), 384, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_bn((1, 1), 384, 96)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_a() - branch1 = Chain(conv_bn((3, 3), 384, 384; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 384, 192)..., - conv_bn((3, 3), 192, 224; pad = 1)..., - conv_bn((3, 3), 224, 256; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_b() - branch1 = Chain(conv_bn((1, 1), 1024, 384)...) - branch2 = Chain(conv_bn((1, 1), 1024, 192)..., - conv_bn((1, 7), 192, 224; pad = (0, 3))..., - conv_bn((7, 1), 224, 256; pad = (3, 0))...) - branch3 = Chain(conv_bn((1, 1), 1024, 192)..., - conv_bn((7, 1), 192, 192; pad = (0, 3))..., - conv_bn((1, 7), 192, 224; pad = (3, 0))..., - conv_bn((7, 1), 224, 224; pad = (0, 3))..., - conv_bn((1, 7), 224, 256; pad = (3, 0))...) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_bn((1, 1), 1024, 128)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function reduction_b() - branch1 = Chain(conv_bn((1, 1), 1024, 192)..., - conv_bn((3, 3), 192, 192; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 1024, 256)..., - conv_bn((1, 7), 256, 256; pad = (0, 3))..., - conv_bn((7, 1), 256, 320; pad = (3, 0))..., - conv_bn((3, 3), 320, 320; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function inceptionv4_c() - branch1 = Chain(conv_bn((1, 1), 1536, 256)...) - branch2 = Chain(conv_bn((1, 1), 1536, 384)..., - Parallel(cat_channels, - Chain(conv_bn((1, 3), 384, 256; pad = (0, 1))...), - Chain(conv_bn((3, 1), 384, 256; pad = (1, 0))...))) - branch3 = Chain(conv_bn((1, 1), 1536, 384)..., - conv_bn((3, 1), 384, 448; pad = (1, 0))..., - conv_bn((1, 3), 448, 512; pad = (0, 1))..., - Parallel(cat_channels, - Chain(conv_bn((1, 3), 512, 256; pad = (0, 1))...), - Chain(conv_bn((3, 1), 512, 256; pad = (1, 0))...))) - branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_bn((1, 1), 1536, 256)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -""" - inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) - -Create an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: number of input channels. - - `dropout`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000) - body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., - mixed_3a(), - mixed_4a(), - mixed_5a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - inceptionv4_a(), - reduction_a(), # mixed_6a - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - inceptionv4_b(), - reduction_b(), # mixed_7a - inceptionv4_c(), - inceptionv4_c(), - inceptionv4_c()) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses)) - return Chain(body, head) -end - -""" - Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - -Creates an Inceptionv4 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `Inceptionv4` does not currently support pretrained weights. -""" -struct Inceptionv4 - layers::Any -end - -function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - layers = inceptionv4(; inchannels, dropout, nclasses) - pretrain && loadpretrain!(layers, "Inceptionv4") - return Inceptionv4(layers) -end - -@functor Inceptionv4 - -(m::Inceptionv4)(x) = m.layers(x) - -backbone(m::Inceptionv4) = m.layers[1] -classifier(m::Inceptionv4) = m.layers[2] - -## Inception-ResNetv2 - -function mixed_5b() - branch1 = Chain(conv_bn((1, 1), 192, 96)...) - branch2 = Chain(conv_bn((1, 1), 192, 48)..., - conv_bn((5, 5), 48, 64; pad = 2)...) - branch3 = Chain(conv_bn((1, 1), 192, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) - branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), - conv_bn((1, 1), 192, 64)...) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function block35(scale = 1.0f0) - branch1 = Chain(conv_bn((1, 1), 320, 32)...) - branch2 = Chain(conv_bn((1, 1), 320, 32)..., - conv_bn((3, 3), 32, 32; pad = 1)...) - branch3 = Chain(conv_bn((1, 1), 320, 32)..., - conv_bn((3, 3), 32, 48; pad = 1)..., - conv_bn((3, 3), 48, 64; pad = 1)...) - branch4 = Chain(conv_bn((1, 1), 128, 320)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), - branch4, inputscale(scale; activation = relu)), +) -end - -function mixed_6a() - branch1 = Chain(conv_bn((3, 3), 320, 384; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 320, 256)..., - conv_bn((3, 3), 256, 256; pad = 1)..., - conv_bn((3, 3), 256, 384; stride = 2)...) - branch3 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3) -end - -function block17(scale = 1.0f0) - branch1 = Chain(conv_bn((1, 1), 1088, 192)...) - branch2 = Chain(conv_bn((1, 1), 1088, 128)..., - conv_bn((1, 7), 128, 160; pad = (0, 3))..., - conv_bn((7, 1), 160, 192; pad = (3, 0))...) - branch3 = Chain(conv_bn((1, 1), 384, 1088)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), - branch3, inputscale(scale; activation = relu)), +) -end - -function mixed_7a() - branch1 = Chain(conv_bn((1, 1), 1088, 256)..., - conv_bn((3, 3), 256, 384; stride = 2)...) - branch2 = Chain(conv_bn((1, 1), 1088, 256)..., - conv_bn((3, 3), 256, 288; stride = 2)...) - branch3 = Chain(conv_bn((1, 1), 1088, 256)..., - conv_bn((3, 3), 256, 288; pad = 1)..., - conv_bn((3, 3), 288, 320; stride = 2)...) - branch4 = MaxPool((3, 3); stride = 2) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) -end - -function block8(scale = 1.0f0; activation = identity) - branch1 = Chain(conv_bn((1, 1), 2080, 192)...) - branch2 = Chain(conv_bn((1, 1), 2080, 192)..., - conv_bn((1, 3), 192, 224; pad = (0, 1))..., - conv_bn((3, 1), 224, 256; pad = (1, 0))...) - branch3 = Chain(conv_bn((1, 1), 448, 2080)...) - return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), - branch3, inputscale(scale; activation)), +) -end - -""" - inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) - -Creates an InceptionResNetv2 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000) - body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., - MaxPool((3, 3); stride = 2), - conv_bn((3, 3), 64, 80)..., - conv_bn((3, 3), 80, 192)..., - MaxPool((3, 3); stride = 2), - mixed_5b(), - [block35(0.17f0) for _ in 1:10]..., - mixed_6a(), - [block17(0.10f0) for _ in 1:20]..., - mixed_7a(), - [block8(0.20f0) for _ in 1:9]..., - block8(; activation = relu), - conv_bn((1, 1), 2080, 1536)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses)) - return Chain(body, head) -end - -""" - InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - -Creates an InceptionResNetv2 model. -([reference](https://arxiv.org/abs/1602.07261)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `InceptionResNetv2` does not currently support pretrained weights. -""" -struct InceptionResNetv2 - layers::Any -end - -function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, - nclasses = 1000) - layers = inceptionresnetv2(; inchannels, dropout, nclasses) - pretrain && loadpretrain!(layers, "InceptionResNetv2") - return InceptionResNetv2(layers) -end - -@functor InceptionResNetv2 - -(m::InceptionResNetv2)(x) = m.layers(x) - -backbone(m::InceptionResNetv2) = m.layers[1] -classifier(m::InceptionResNetv2) = m.layers[2] - -## Xception - -""" - xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, - grow_at_start = true) - -Create an Xception block. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `inchannels`: The number of channels in the input. The default value is 3. - - `outchannels`: number of output channels. - - `nrepeats`: number of repeats of depthwise separable convolution layers. - - `stride`: stride by which to downsample the input. - - `start_with_relu`: if true, start the block with a ReLU activation. - - `grow_at_start`: if true, increase the number of channels at the first convolution. -""" -function xception_block(inchannels, outchannels, nrepeats; stride = 1, - start_with_relu = true, - grow_at_start = true) - if outchannels != inchannels || stride != 1 - skip = conv_bn((1, 1), inchannels, outchannels, identity; stride = stride, - bias = false) - else - skip = [identity] - end - layers = [] - for i in 1:nrepeats - if grow_at_start - inc = i == 1 ? inchannels : outchannels - outc = outchannels - else - inc = inchannels - outc = i == nrepeats ? outchannels : inchannels - end - push!(layers, x -> relu.(x)) - append!(layers, - depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false, - use_bn = (false, false))) - push!(layers, BatchNorm(outc)) - end - layers = start_with_relu ? layers : layers[2:end] - push!(layers, MaxPool((3, 3); stride = stride, pad = 1)) - return Parallel(+, Chain(skip...), Chain(layers...)) -end - -""" - xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) - -Creates an Xception model. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. -""" -function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000) - body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)..., - conv_bn((3, 3), 32, 64; bias = false)..., - xception_block(64, 128, 2; stride = 2, start_with_relu = false), - xception_block(128, 256, 2; stride = 2), - xception_block(256, 728, 2; stride = 2), - [xception_block(728, 728, 3) for _ in 1:8]..., - xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)..., - depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses)) - return Chain(body, head) -end - -struct Xception - layers::Any -end - -""" - Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - -Creates an Xception model. -([reference](https://arxiv.org/abs/1610.02357)) - -# Arguments - - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. - - `inchannels`: The number of channels in the input. The default value is 3. - - `dropout`: rate of dropout in classifier head. - - `nclasses`: the number of output classes. - -!!! warning - - `Xception` does not currently support pretrained weights. -""" -function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - layers = xception(; inchannels, dropout, nclasses) - pretrain && loadpretrain!(layers, "xception") - return Xception(layers) -end - -@functor Xception - -(m::Xception)(x) = m.layers(x) - -backbone(m::Xception) = m.layers[1] -classifier(m::Xception) = m.layers[2] diff --git a/src/convnets/googlenet.jl b/src/convnets/inception/googlenet.jl similarity index 98% rename from src/convnets/googlenet.jl rename to src/convnets/inception/googlenet.jl index 318463494..8a88ca943 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/inception/googlenet.jl @@ -16,15 +16,12 @@ Create an inception module for use in GoogLeNet """ function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj) branch1 = Chain(Conv((1, 1), inplanes => out_1x1)) - branch2 = Chain(Conv((1, 1), inplanes => red_3x3), Conv((3, 3), red_3x3 => out_3x3; pad = 1)) - branch3 = Chain(Conv((1, 1), inplanes => red_5x5), Conv((5, 5), red_5x5 => out_5x5; pad = 2)) branch4 = Chain(MaxPool((3, 3); stride = 1, pad = 1), Conv((1, 1), inplanes => pool_proj)) - return Parallel(cat_channels, branch1, branch2, branch3, branch4) end @@ -83,15 +80,16 @@ See also [`googlenet`](#). struct GoogLeNet layers::Any end +@functor GoogLeNet function GoogLeNet(; pretrain = false, nclasses = 1000) layers = googlenet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "GoogLeNet") + if pretrain + loadpretrain!(layers, "GoogLeNet") + end return GoogLeNet(layers) end -@functor GoogLeNet - (m::GoogLeNet)(x) = m.layers(x) backbone(m::GoogLeNet) = m.layers[1] diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl new file mode 100644 index 000000000..4b4b78706 --- /dev/null +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -0,0 +1,133 @@ +function mixed_5b() + branch1 = Chain(conv_norm((1, 1), 192, 96)...) + branch2 = Chain(conv_norm((1, 1), 192, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3 = Chain(conv_norm((1, 1), 192, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), 192, 64)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function block35(scale = 1.0f0) + branch1 = Chain(conv_norm((1, 1), 320, 32)...) + branch2 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 32; pad = 1)...) + branch3 = Chain(conv_norm((1, 1), 320, 32)..., + conv_norm((3, 3), 32, 48; pad = 1)..., + conv_norm((3, 3), 48, 64; pad = 1)...) + branch4 = Chain(conv_norm((1, 1), 128, 320)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2, branch3), + branch4, inputscale(scale; activation = relu)), +) +end + +function mixed_6a() + branch1 = Chain(conv_norm((3, 3), 320, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 320, 256)..., + conv_norm((3, 3), 256, 256; pad = 1)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function block17(scale = 1.0f0) + branch1 = Chain(conv_norm((1, 1), 1088, 192)...) + branch2 = Chain(conv_norm((1, 1), 1088, 128)..., + conv_norm((1, 7), 128, 160; pad = (0, 3))..., + conv_norm((7, 1), 160, 192; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 384, 1088)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), + branch3, inputscale(scale; activation = relu)), +) +end + +function mixed_7a() + branch1 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; stride = 2)...) + branch3 = Chain(conv_norm((1, 1), 1088, 256)..., + conv_norm((3, 3), 256, 288; pad = 1)..., + conv_norm((3, 3), 288, 320; stride = 2)...) + branch4 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function block8(scale = 1.0f0; activation = identity) + branch1 = Chain(conv_norm((1, 1), 2080, 192)...) + branch2 = Chain(conv_norm((1, 1), 2080, 192)..., + conv_norm((1, 3), 192, 224; pad = (0, 1))..., + conv_norm((3, 1), 224, 256; pad = (1, 0))...) + branch3 = Chain(conv_norm((1, 1), 448, 2080)...) + return SkipConnection(Chain(Parallel(cat_channels, branch1, branch2), + branch3, inputscale(scale; activation)), +) +end + +""" + inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an InceptionResNetv2 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function inceptionresnetv2(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + MaxPool((3, 3); stride = 2), + conv_norm((3, 3), 64, 80)..., + conv_norm((3, 3), 80, 192)..., + MaxPool((3, 3); stride = 2), + mixed_5b(), + [block35(0.17f0) for _ in 1:10]..., + mixed_6a(), + [block17(0.10f0) for _ in 1:20]..., + mixed_7a(), + [block8(0.20f0) for _ in 1:9]..., + block8(; activation = relu), + conv_norm((1, 1), 2080, 1536)...) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(1536, nclasses)) + return Chain(body, head) +end + +""" + InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an InceptionResNetv2 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `InceptionResNetv2` does not currently support pretrained weights. +""" +struct InceptionResNetv2 + layers::Any +end +@functor InceptionResNetv2 + +function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "InceptionResNetv2") + end + return InceptionResNetv2(layers) +end + +(m::InceptionResNetv2)(x) = m.layers(x) + +backbone(m::InceptionResNetv2) = m.layers[1] +classifier(m::InceptionResNetv2) = m.layers[2] diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl new file mode 100644 index 000000000..68b283838 --- /dev/null +++ b/src/convnets/inception/inceptionv3.jl @@ -0,0 +1,196 @@ +## Inceptionv3 + +""" + inceptionv3_a(inplanes, pool_proj) + +Create an Inception-v3 style-A module +(ref: Fig. 5 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps + - `pool_proj`: the number of output feature maps for the pooling projection +""" +function inceptionv3_a(inplanes, pool_proj) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 64)) + branch5x5 = Chain(conv_norm((1, 1), inplanes, 48)..., + conv_norm((5, 5), 48, 64; pad = 2)...) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, pool_proj)...) + return Parallel(cat_channels, + branch1x1, branch5x5, branch3x3, branch_pool) +end + +""" + inceptionv3_b(inplanes) + +Create an Inception-v3 style-B module +(ref: Fig. 10 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_b(inplanes) + branch3x3_1 = Chain(conv_norm((3, 3), inplanes, 384; stride = 2)) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; stride = 2)...) + branch_pool = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, + branch3x3_1, branch3x3_2, branch_pool) +end + +""" + inceptionv3_c(inplanes, inner_planes, n = 7) + +Create an Inception-v3 style-C module +(ref: Fig. 6 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps + - `inner_planes`: the number of output feature maps within each branch + - `n`: the "grid size" (kernel size) for the convolution layers +""" +function inceptionv3_c(inplanes, inner_planes, n = 7) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 192)) + branch7x7_1 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, 192; pad = (3, 0))...) + branch7x7_2 = Chain(conv_norm((1, 1), inplanes, inner_planes)..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, inner_planes; pad = (0, 3))..., + conv_norm((n, 1), inner_planes, inner_planes; pad = (3, 0))..., + conv_norm((1, n), inner_planes, 192; pad = (0, 3))...) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, 192)...) + return Parallel(cat_channels, + branch1x1, branch7x7_1, branch7x7_2, branch_pool) +end + +""" + inceptionv3_d(inplanes) + +Create an Inception-v3 style-D module +(ref: [pytorch](https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/inception.py#L322)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_d(inplanes) + branch3x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((3, 3), 192, 320; stride = 2)...) + branch7x7x3 = Chain(conv_norm((1, 1), inplanes, 192)..., + conv_norm((1, 7), 192, 192; pad = (0, 3))..., + conv_norm((7, 1), 192, 192; pad = (3, 0))..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch_pool = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, + branch3x3, branch7x7x3, branch_pool) +end + +""" + inceptionv3_e(inplanes) + +Create an Inception-v3 style-E module +(ref: Fig. 7 in [paper](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `inplanes`: number of input feature maps +""" +function inceptionv3_e(inplanes) + branch1x1 = Chain(conv_norm((1, 1), inplanes, 320)) + branch3x3_1 = Chain(conv_norm((1, 1), inplanes, 384)) + branch3x3_1a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_1b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch3x3_2 = Chain(conv_norm((1, 1), inplanes, 448)..., + conv_norm((3, 3), 448, 384; pad = 1)...) + branch3x3_2a = Chain(conv_norm((1, 3), 384, 384; pad = (0, 1))) + branch3x3_2b = Chain(conv_norm((3, 1), 384, 384; pad = (1, 0))) + branch_pool = Chain(MeanPool((3, 3); pad = 1, stride = 1), + conv_norm((1, 1), inplanes, 192)...) + return Parallel(cat_channels, + branch1x1, + Chain(branch3x3_1, + Parallel(cat_channels, + branch3x3_1a, branch3x3_1b)), + Chain(branch3x3_2, + Parallel(cat_channels, + branch3x3_2a, branch3x3_2b)), + branch_pool) +end + +""" + inceptionv3(; nclasses = 1000) + +Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). + +# Arguments + + - `nclasses`: the number of output classes +""" +function inceptionv3(; nclasses = 1000) + layer = Chain(Chain(conv_norm((3, 3), 3, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + MaxPool((3, 3); stride = 2), + conv_norm((1, 1), 64, 80)..., + conv_norm((3, 3), 80, 192)..., + MaxPool((3, 3); stride = 2), + inceptionv3_a(192, 32), + inceptionv3_a(256, 64), + inceptionv3_a(288, 64), + inceptionv3_b(288), + inceptionv3_c(768, 128), + inceptionv3_c(768, 160), + inceptionv3_c(768, 160), + inceptionv3_c(768, 192), + inceptionv3_d(768), + inceptionv3_e(1280), + inceptionv3_e(2048)), + Chain(AdaptiveMeanPool((1, 1)), + Dropout(0.2), + MLUtils.flatten, + Dense(2048, nclasses))) + return layer +end + +""" + Inceptionv3(; pretrain = false, nclasses = 1000) + +Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). +See also [`inceptionv3`](#). + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +!!! warning + + `Inceptionv3` does not currently support pretrained weights. +""" +struct Inceptionv3 + layers::Any +end + +function Inceptionv3(; pretrain = false, nclasses = 1000) + layers = inceptionv3(; nclasses = nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv3") + end + return Inceptionv3(layers) +end + +@functor Inceptionv3 + +(m::Inceptionv3)(x) = m.layers(x) + +backbone(m::Inceptionv3) = m.layers[1] +classifier(m::Inceptionv3) = m.layers[2] diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl new file mode 100644 index 000000000..bb03646ec --- /dev/null +++ b/src/convnets/inception/inceptionv4.jl @@ -0,0 +1,158 @@ +function mixed_3a() + return Parallel(cat_channels, + MaxPool((3, 3); stride = 2), + Chain(conv_norm((3, 3), 64, 96; stride = 2)...)) +end + +function mixed_4a() + return Parallel(cat_channels, + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((3, 3), 64, 96)...), + Chain(conv_norm((1, 1), 160, 64)..., + conv_norm((1, 7), 64, 64; pad = (0, 3))..., + conv_norm((7, 1), 64, 64; pad = (3, 0))..., + conv_norm((3, 3), 64, 96)...)) +end + +function mixed_5a() + return Parallel(cat_channels, + Chain(conv_norm((3, 3), 192, 192; stride = 2)...), + MaxPool((3, 3); stride = 2)) +end + +function inceptionv4_a() + branch1 = Chain(conv_norm((1, 1), 384, 96)...) + branch2 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)...) + branch3 = Chain(conv_norm((1, 1), 384, 64)..., + conv_norm((3, 3), 64, 96; pad = 1)..., + conv_norm((3, 3), 96, 96; pad = 1)...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 384, 96)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_a() + branch1 = Chain(conv_norm((3, 3), 384, 384; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 384, 192)..., + conv_norm((3, 3), 192, 224; pad = 1)..., + conv_norm((3, 3), 224, 256; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_b() + branch1 = Chain(conv_norm((1, 1), 1024, 384)...) + branch2 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((1, 7), 192, 224; pad = (0, 3))..., + conv_norm((7, 1), 224, 256; pad = (3, 0))...) + branch3 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((7, 1), 192, 192; pad = (0, 3))..., + conv_norm((1, 7), 192, 224; pad = (3, 0))..., + conv_norm((7, 1), 224, 224; pad = (0, 3))..., + conv_norm((1, 7), 224, 256; pad = (3, 0))...) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1024, 128)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +function reduction_b() + branch1 = Chain(conv_norm((1, 1), 1024, 192)..., + conv_norm((3, 3), 192, 192; stride = 2)...) + branch2 = Chain(conv_norm((1, 1), 1024, 256)..., + conv_norm((1, 7), 256, 256; pad = (0, 3))..., + conv_norm((7, 1), 256, 320; pad = (3, 0))..., + conv_norm((3, 3), 320, 320; stride = 2)...) + branch3 = MaxPool((3, 3); stride = 2) + return Parallel(cat_channels, branch1, branch2, branch3) +end + +function inceptionv4_c() + branch1 = Chain(conv_norm((1, 1), 1536, 256)...) + branch2 = Chain(conv_norm((1, 1), 1536, 384)..., + Parallel(cat_channels, + Chain(conv_norm((1, 3), 384, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 384, 256; pad = (1, 0))...))) + branch3 = Chain(conv_norm((1, 1), 1536, 384)..., + conv_norm((3, 1), 384, 448; pad = (1, 0))..., + conv_norm((1, 3), 448, 512; pad = (0, 1))..., + Parallel(cat_channels, + Chain(conv_norm((1, 3), 512, 256; pad = (0, 1))...), + Chain(conv_norm((3, 1), 512, 256; pad = (1, 0))...))) + branch4 = Chain(MeanPool((3, 3); stride = 1, pad = 1), conv_norm((1, 1), 1536, 256)...) + return Parallel(cat_channels, branch1, branch2, branch3, branch4) +end + +""" + inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Create an Inceptionv4 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function inceptionv4(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., + conv_norm((3, 3), 32, 32)..., + conv_norm((3, 3), 32, 64; pad = 1)..., + mixed_3a(), + mixed_4a(), + mixed_5a(), + inceptionv4_a(), + inceptionv4_a(), + inceptionv4_a(), + inceptionv4_a(), + reduction_a(), # mixed_6a + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + inceptionv4_b(), + reduction_b(), # mixed_7a + inceptionv4_c(), + inceptionv4_c(), + inceptionv4_c()) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(1536, nclasses)) + return Chain(body, head) +end + +""" + Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Inceptionv4 model. +([reference](https://arxiv.org/abs/1602.07261)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `Inceptionv4` does not currently support pretrained weights. +""" +struct Inceptionv4 + layers::Any +end +@functor Inceptionv4 + +function Inceptionv4(; pretrain = false, inchannels = 3, dropout_rate = 0.0, + nclasses = 1000) + layers = inceptionv4(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "Inceptionv4") + end + return Inceptionv4(layers) +end + +(m::Inceptionv4)(x) = m.layers(x) + +backbone(m::Inceptionv4) = m.layers[1] +classifier(m::Inceptionv4) = m.layers[2] diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl new file mode 100644 index 000000000..3c6d8331a --- /dev/null +++ b/src/convnets/inception/xception.jl @@ -0,0 +1,106 @@ +""" + xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true, + grow_at_start = true) + +Create an Xception block. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `inchannels`: The number of channels in the input. + - `outchannels`: number of output channels. + - `nrepeats`: number of repeats of depthwise separable convolution layers. + - `stride`: stride by which to downsample the input. + - `start_with_relu`: if true, start the block with a ReLU activation. + - `grow_at_start`: if true, increase the number of channels at the first convolution. +""" +function xception_block(inchannels, outchannels, nrepeats; stride = 1, + start_with_relu = true, + grow_at_start = true) + if outchannels != inchannels || stride != 1 + skip = conv_norm((1, 1), inchannels, outchannels, identity; stride = stride, + bias = false) + else + skip = [identity] + end + layers = [] + for i in 1:nrepeats + if grow_at_start + inc = i == 1 ? inchannels : outchannels + outc = outchannels + else + inc = inchannels + outc = i == nrepeats ? outchannels : inchannels + end + push!(layers, relu) + append!(layers, + depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, + use_norm = (false, false))) + push!(layers, BatchNorm(outc)) + end + layers = start_with_relu ? layers : layers[2:end] + push!(layers, MaxPool((3, 3); stride = stride, pad = 1)) + return Parallel(+, Chain(skip...), Chain(layers...)) +end + +""" + xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Xception model. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. +""" +function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + body = Chain(conv_norm((3, 3), inchannels, 32; stride = 2, bias = false)..., + conv_norm((3, 3), 32, 64; bias = false)..., + xception_block(64, 128, 2; stride = 2, start_with_relu = false), + xception_block(128, 256, 2; stride = 2), + xception_block(256, 728, 2; stride = 2), + [xception_block(728, 728, 3) for _ in 1:8]..., + xception_block(728, 1024, 2; stride = 2, grow_at_start = false), + depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)..., + depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...) + head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate), + Dense(2048, nclasses)) + return Chain(body, head) +end + +struct Xception + layers::Any +end +@functor Xception + +""" + Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + +Creates an Xception model. +([reference](https://arxiv.org/abs/1610.02357)) + +# Arguments + + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet. + - `inchannels`: number of input channels. + - `dropout_rate`: rate of dropout in classifier head. + - `nclasses`: the number of output classes. + +!!! warning + + `Xception` does not currently support pretrained weights. +""" +function Xception(; pretrain = false, inchannels = 3, dropout_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, dropout_rate, nclasses) + if pretrain + loadpretrain!(layers, "xception") + end + return Xception(layers) +end + +(m::Xception)(x) = m.layers(x) + +backbone(m::Xception) = m.layers[1] +classifier(m::Xception) = m.layers[2] diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl deleted file mode 100644 index 8de98e886..000000000 --- a/src/convnets/mobilenet.jl +++ /dev/null @@ -1,327 +0,0 @@ -# MobileNetv1 - -""" - mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - nclasses = 1000) - -Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `dw`: Set true to use a depthwise separable convolution or false for regular convolution - + `o`: The number of output feature maps - + `s`: The stride of the convolutional kernel - + `r`: The number of time this configuration block is repeated - - `activate`: The activation function to use throughout the network - - `inchannels`: The number of input channels. The default value is 3. - - `nclasses`: The number of output classes -""" -function mobilenetv1(width_mult, config; - activation = relu, - inchannels = 3, - nclasses = 1000) - layers = [] - for (dw, outch, stride, nrepeats) in config - outch = Int(outch * width_mult) - for _ in 1:nrepeats - layer = dw ? - depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : - conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1, - bias = false) - append!(layers, layer) - inchannels = outch - end - end - - return Chain(Chain(layers), - Chain(GlobalMeanPool(), - MLUtils.flatten, - Dense(inchannels, nclasses))) -end - -const mobilenetv1_configs = [ - # dw, c, s, r - (false, 32, 2, 1), - (true, 64, 1, 1), - (true, 128, 2, 1), - (true, 128, 1, 1), - (true, 256, 2, 1), - (true, 256, 1, 1), - (true, 512, 2, 1), - (true, 512, 1, 5), - (true, 1024, 2, 1), - (true, 1024, 1, 1), -] - -""" - MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv1 model with the baseline configuration -([reference](https://arxiv.org/abs/1704.04861v1)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv1`](#). -""" -struct MobileNetv1 - layers::Any -end - -function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv1")) - return MobileNetv1(layers) -end - -@functor MobileNetv1 - -(m::MobileNetv1)(x) = m.layers(x) - -backbone(m::MobileNetv1) = m.layers[1] -classifier(m::MobileNetv1) = m.layers[2] - -# MobileNetv2 - -""" - mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - -Create a MobileNetv2 model. -([reference](https://arxiv.org/abs/1801.04381)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer - + `c`: The number of output feature maps - + `n`: The number of times a block is repeated - + `s`: The stride of the convolutional kernel - + `a`: The activation function used in the bottleneck layer - - `inchannels`: The number of input channels. The default value is 3. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: The number of output classes -""" -function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) - # building first layer - inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) - layers = [] - append!(layers, conv_bn((3, 3), inchannels, inplanes; pad = 1, stride = 2)) - # building inverted residual blocks - for (t, c, n, s, a) in configs - outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) - for i in 1:n - push!(layers, - invertedresidual(3, inplanes, inplanes * t, outplanes, a; - stride = i == 1 ? s : 1)) - inplanes = outplanes - end - end - # building last several layers - outplanes = (width_mult > 1) ? - _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : - max_width - return Chain(Chain(Chain(layers), - conv_bn((1, 1), inplanes, outplanes, relu6; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(outplanes, nclasses))) -end - -# Layer configurations for MobileNetv2 -const mobilenetv2_configs = [ - # t, c, n, s, a - (1, 16, 1, 1, relu6), - (6, 24, 2, 2, relu6), - (6, 32, 3, 2, relu6), - (6, 64, 4, 2, relu6), - (6, 96, 3, 1, relu6), - (6, 160, 3, 2, relu6), - (6, 320, 1, 1, relu6), -] - -# Model definition for MobileNetv2 -struct MobileNetv2 - layers::Any -end - -""" - MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv2 model with the specified configuration. -([reference](https://arxiv.org/abs/1801.04381)). -Set `pretrain` to `true` to load the pretrained weights for ImageNet. - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of input channels. The default value is 3. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `nclasses`: The number of output classes - -See also [`Metalhead.mobilenetv2`](#). -""" -function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, - nclasses = 1000) - layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv2")) - return MobileNetv2(layers) -end - -@functor MobileNetv2 - -(m::MobileNetv2)(x) = m.layers(x) - -backbone(m::MobileNetv2) = m.layers[1] -classifier(m::MobileNetv2) = m.layers[2] - -# MobileNetv3 - -""" - mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - -Create a MobileNetv3 model. -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - - `configs`: a "list of tuples" configuration for each layer that details: - - + `k::Integer` - The size of the convolutional kernel - + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - + `t::Integer` - The number of output feature maps for a given block - + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - + `s::Integer` - The stride of the convolutional kernel - + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - `inchannels`: The number of input channels. The default value is 3. - - `max_width`: The maximum number of feature maps in any layer of the network - - `nclasses`: the number of output classes -""" -function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) - # building first layer - inplanes = _round_channels(16 * width_mult, 8) - layers = [] - append!(layers, - conv_bn((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, - bias = false)) - explanes = 0 - # building inverted residual blocks - for (k, t, c, r, a, s) in configs - # inverted residual layers - outplanes = _round_channels(c * width_mult, 8) - explanes = _round_channels(inplanes * t, 8) - push!(layers, - invertedresidual(k, inplanes, explanes, outplanes, a; - stride = s, reduction = r)) - inplanes = outplanes - end - # building last several layers - output_channel = max_width - output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : - output_channel - classifier = Chain(Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses)) - return Chain(Chain(Chain(layers), - conv_bn((1, 1), inplanes, explanes, hardswish; bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) -end - -# Configurations for small and large mode for MobileNetv3 -mobilenetv3_configs = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) - -# Model definition for MobileNetv3 -struct MobileNetv3 - layers::Any -end - -""" - MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - -Create a MobileNetv3 model with the specified configuration. -([reference](https://arxiv.org/abs/1905.02244)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -# Arguments - - - `mode`: :small or :large for the size of the model (see paper). - - `width_mult`: Controls the number of output feature maps in each block - (with 1.0 being the default in the paper; - this is usually a value between 0.1 and 1.4) - - `inchannels`: The number of channels in the input. The default value is 3. - - `pretrain`: whether to load the pre-trained weights for ImageNet - - `nclasses`: the number of output classes - -See also [`Metalhead.mobilenetv3`](#). -""" -function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, - pretrain = false, nclasses = 1000) - @assert mode in [:large, :small] "`mode` has to be either :large or :small" - max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, - nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv3", mode)) - return MobileNetv3(layers) -end - -@functor MobileNetv3 - -(m::MobileNetv3)(x) = m.layers(x) - -backbone(m::MobileNetv3) = m.layers[1] -classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl new file mode 100644 index 000000000..fffa93a4d --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -0,0 +1,98 @@ +""" + mobilenetv1(width_mult, config; + activation = relu, + inchannels = 3, + nclasses = 1000) + +Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper) + + - `configs`: A "list of tuples" configuration for each layer that details: + + + `dw`: Set true to use a depthwise separable convolution or false for regular convolution + + `o`: The number of output feature maps + + `s`: The stride of the convolutional kernel + + `r`: The number of time this configuration block is repeated + - `activate`: The activation function to use throughout the network + - `inchannels`: The number of input channels. The default value is 3. + - `nclasses`: The number of output classes +""" +function mobilenetv1(width_mult, config; + activation = relu, + inchannels = 3, + nclasses = 1000) + layers = [] + for (dw, outch, stride, nrepeats) in config + outch = Int(outch * width_mult) + for _ in 1:nrepeats + layer = dw ? + depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; + stride = stride, pad = 1, bias = false) : + conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, + bias = false) + append!(layers, layer) + inchannels = outch + end + end + + return Chain(Chain(layers), + Chain(GlobalMeanPool(), + MLUtils.flatten, + Dense(inchannels, nclasses))) +end + +# Layer configurations for MobileNetv1 +const MOBILENETV1_CONFIGS = [ + # dw, c, s, r + (false, 32, 2, 1), + (true, 64, 1, 1), + (true, 128, 2, 1), + (true, 128, 1, 1), + (true, 256, 2, 1), + (true, 256, 1, 1), + (true, 512, 2, 1), + (true, 512, 1, 5), + (true, 1024, 2, 1), + (true, 1024, 1, 1), +] + +""" + MobileNetv1(width_mult = 1; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv1 model with the baseline configuration +([reference](https://arxiv.org/abs/1704.04861v1)). +Set `pretrain` to `true` to load the pretrained weights for ImageNet. + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of input channels. + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `nclasses`: The number of output classes + +See also [`Metalhead.mobilenetv1`](#). +""" +struct MobileNetv1 + layers::Any +end +@functor MobileNetv1 + +function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, + nclasses = 1000) + layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv1")) + end + return MobileNetv1(layers) +end + +(m::MobileNetv1)(x) = m.layers(x) + +backbone(m::MobileNetv1) = m.layers[1] +classifier(m::MobileNetv1) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl new file mode 100644 index 000000000..a97e7dda1 --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -0,0 +1,96 @@ +""" + mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) + +Create a MobileNetv2 model. +([reference](https://arxiv.org/abs/1801.04381)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper) + + - `configs`: A "list of tuples" configuration for each layer that details: + + + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer + + `c`: The number of output feature maps + + `n`: The number of times a block is repeated + + `s`: The stride of the convolutional kernel + + `a`: The activation function used in the bottleneck layer + - `inchannels`: The number of input channels. + - `max_width`: The maximum number of feature maps in any layer of the network + - `nclasses`: The number of output classes +""" +function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, nclasses = 1000) + # building first layer + inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) + layers = [] + append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) + # building inverted residual blocks + for (t, c, n, s, a) in configs + outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) + for i in 1:n + push!(layers, + invertedresidual(3, inplanes, inplanes * t, outplanes, a; + stride = i == 1 ? s : 1)) + inplanes = outplanes + end + end + # building last several layers + outplanes = (width_mult > 1) ? + _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : + max_width + return Chain(Chain(Chain(layers), + conv_norm((1, 1), inplanes, outplanes, relu6; bias = false)...), + Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, + Dense(outplanes, nclasses))) +end + +# Layer configurations for MobileNetv2 +const MOBILENETV2_CONFIGS = [ + # t, c, n, s, a + (1, 16, 1, 1, relu6), + (6, 24, 2, 2, relu6), + (6, 32, 3, 2, relu6), + (6, 64, 4, 2, relu6), + (6, 96, 3, 1, relu6), + (6, 160, 3, 2, relu6), + (6, 320, 1, 1, relu6), +] + +struct MobileNetv2 + layers::Any +end +@functor MobileNetv2 + +""" + MobileNetv2(width_mult = 1.0; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv2 model with the specified configuration. +([reference](https://arxiv.org/abs/1801.04381)). +Set `pretrain` to `true` to load the pretrained weights for ImageNet. + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of input channels. + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `nclasses`: The number of output classes + +See also [`Metalhead.mobilenetv2`](#). +""" +function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, + nclasses = 1000) + layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) + pretrain && loadpretrain!(layers, string("MobileNetv2")) + if pretrain + loadpretrain!(layers, string("MobileNetv2")) + end + return MobileNetv2(layers) +end + +(m::MobileNetv2)(x) = m.layers(x) + +backbone(m::MobileNetv2) = m.layers[1] +classifier(m::MobileNetv2) = m.layers[2] diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl new file mode 100644 index 000000000..d8666c5f3 --- /dev/null +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -0,0 +1,128 @@ +""" + mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) + +Create a MobileNetv3 model. +([reference](https://arxiv.org/abs/1905.02244)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + + - `configs`: a "list of tuples" configuration for each layer that details: + + + `k::Integer` - The size of the convolutional kernel + + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer + + `t::Integer` - The number of output feature maps for a given block + + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers + + `s::Integer` - The stride of the convolutional kernel + + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) + - `inchannels`: The number of input channels. + - `max_width`: The maximum number of feature maps in any layer of the network + - `nclasses`: the number of output classes +""" +function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, nclasses = 1000) + # building first layer + inplanes = _round_channels(16 * width_mult, 8) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, hardswish; pad = 1, stride = 2, + bias = false)) + explanes = 0 + # building inverted residual blocks + for (k, t, c, r, a, s) in configs + # inverted residual layers + outplanes = _round_channels(c * width_mult, 8) + explanes = _round_channels(inplanes * t, 8) + push!(layers, + invertedresidual(k, inplanes, explanes, outplanes, a; + stride = s, reduction = r)) + inplanes = outplanes + end + # building last several layers + output_channel = max_width + output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : + output_channel + classifier = Chain(Dense(explanes, output_channel, hardswish), + Dropout(0.2), + Dense(output_channel, nclasses)) + return Chain(Chain(Chain(layers), + conv_norm((1, 1), inplanes, explanes, hardswish; bias = false)...), + Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) +end + +# Layer configurations for small and large models for MobileNetv3 +const MOBILENETV3_CONFIGS = Dict(:small => [ + # k, t, c, SE, a, s + (3, 1, 16, 4, relu, 2), + (3, 4.5, 24, nothing, relu, 2), + (3, 3.67, 24, nothing, relu, 1), + (5, 4, 40, 4, hardswish, 2), + (5, 6, 40, 4, hardswish, 1), + (5, 6, 40, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 2), + (5, 6, 96, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 1), + ], + :large => [ + # k, t, c, SE, a, s + (3, 1, 16, nothing, relu, 1), + (3, 4, 24, nothing, relu, 2), + (3, 3, 24, nothing, relu, 1), + (5, 3, 40, 4, relu, 2), + (5, 3, 40, 4, relu, 1), + (5, 3, 40, 4, relu, 1), + (3, 6, 80, nothing, hardswish, 2), + (3, 2.5, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 2), + (5, 6, 160, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 1), + ]) + +struct MobileNetv3 + layers::Any +end +@functor MobileNetv3 + +""" + MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) + +Create a MobileNetv3 model with the specified configuration. +([reference](https://arxiv.org/abs/1905.02244)). +Set `pretrain = true` to load the model with pre-trained weights for ImageNet. + +# Arguments + + - `mode`: :small or :large for the size of the model (see paper). + - `width_mult`: Controls the number of output feature maps in each block + (with 1.0 being the default in the paper; + this is usually a value between 0.1 and 1.4) + - `inchannels`: The number of channels in the input. + - `pretrain`: whether to load the pre-trained weights for ImageNet + - `nclasses`: the number of output classes + +See also [`Metalhead.mobilenetv3`](#). +""" +function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = 3, + pretrain = false, nclasses = 1000) + @assert mode in [:large, :small] "`mode` has to be either :large or :small" + max_width = (mode == :large) ? 1280 : 1024 + layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width, + nclasses) + if pretrain + loadpretrain!(layers, string("MobileNetv3", mode)) + end + return MobileNetv3(layers) +end + +(m::MobileNetv3)(x) = m.layers(x) + +backbone(m::MobileNetv3) = m.layers[1] +classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl deleted file mode 100644 index 53d1fd6e3..000000000 --- a/src/convnets/resnet.jl +++ /dev/null @@ -1,259 +0,0 @@ -""" - basicblock(inplanes, outplanes, downsample = false) - -Create a basic residual block -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input -""" -function basicblock(inplanes, outplanes, downsample = false) - stride = downsample ? 2 : 1 - return Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, - bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, - bias = false)...) -end - -""" - bottleneck(inplanes, outplanes, downsample = false; stride = [1, (downsample ? 2 : 1), 1]) - -Create a bottleneck residual block -([reference](https://arxiv.org/abs/1512.03385v1)). The bottleneck is composed of -3 convolutional layers each with the given `stride`. -By default, `stride` implements ["ResNet v1.5"](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch) -which uses `stride == [1, 2, 1]` when `downsample == true`. -This version is standard across various ML frameworks. -The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead. - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input - - `stride`: a list of the stride of the 3 convolutional layers -""" -function bottleneck(inplanes, outplanes, downsample = false; - stride = [1, (downsample ? 2 : 1), 1]) - return Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], - bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, - bias = false)..., - conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], - bias = false)...) -end - -""" - bottleneck_v1(inplanes, outplanes, downsample = false) - -Create a bottleneck residual block -([reference](https://arxiv.org/abs/1512.03385v1)). The bottleneck is composed of -3 convolutional layers with all a stride of 1 except the first convolutional -layer which has a stride of 2. - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: a list of the number of output feature maps for each convolution - within the residual block - - `downsample`: set to `true` to downsample the input -""" -function bottleneck_v1(inplanes, outplanes, downsample = false) - return bottleneck(inplanes, outplanes, downsample; - stride = [(downsample ? 2 : 1), 1, 1]) -end - -""" - resnet(block, residuals::NTuple{2, Any}, connection = addrelu; - channel_config, block_config, nclasses = 1000) - -Create a ResNet model -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - `residuals`: a 2-tuple of functions with input `(inplanes, outplanes, downsample=false)`, - each of which will return a function that will be used as a new "skip" path to match a residual block. - [`Metalhead.skip_identity`](#) and [`Metalhead.skip_projection`](#) can be used here. - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection = addrelu; - channel_config, block_config, nclasses = 1000) - inplanes = 64 - baseplanes = 64 - layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false)) - push!(layers, MaxPool((3, 3); stride = (2, 2), pad = (1, 1))) - for (i, nrepeats) in enumerate(block_config) - # output planes within a block - outplanes = baseplanes .* channel_config - # push first skip connection on using first residual - # downsample the residual path if this is the first repetition of a block - push!(layers, - Parallel(connection, block(inplanes, outplanes, i != 1), - residuals[i][1](inplanes, outplanes[end], i != 1))) - # push remaining skip connections on using second residual - inplanes = outplanes[end] - for _ in 2:nrepeats - push!(layers, - Parallel(connection, block(inplanes, outplanes, false), - residuals[i][2](inplanes, outplanes[end], false))) - inplanes = outplanes[end] - end - # next set of output plane base is doubled - baseplanes *= 2 - end - # next set of output plane base is doubled - baseplanes *= 2 - return Chain(Chain(layers), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(inplanes, nclasses))) -end - -""" - resnet(block, shortcut_config::Symbol, connection = addrelu; - channel_config, block_config, nclasses = 1000) - -Create a ResNet model -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments - - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - - `shortcut_config`: the type of shortcut style (either `:A`, `:B`, or `:C`) - - + `:A`: uses a [`Metalhead.skip_identity`](#) for all residual blocks - + `:B`: uses a [`Metalhead.skip_projection`](#) for the first residual block - and [`Metalhead.skip_identity`](@) for the remaining residual blocks - + `:C`: uses a [`Metalhead.skip_projection`](#) for all residual blocks - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnet(block, shortcut_config::AbstractVector{<:Symbol}, args...; kwargs...) - shortcut_dict = Dict(:A => (skip_identity, skip_identity), - :B => (skip_projection, skip_identity), - :C => (skip_projection, skip_projection)) - if any(sc -> !haskey(shortcut_dict, sc), shortcut_config) - error("Unrecognized shortcut_config ($shortcut_config) passed to `resnet` (use only :A, :B, or :C).") - end - shortcut = [shortcut_dict[sc] for sc in shortcut_config] - return resnet(block, shortcut, args...; kwargs...) -end - -function resnet(block, shortcut_config::Symbol, args...; block_config, kwargs...) - return resnet(block, fill(shortcut_config, length(block_config)), args...; - block_config = block_config, kwargs...) -end - -function resnet(block, residuals::NTuple{2}, args...; kwargs...) - return resnet(block, [residuals], args...; kwargs...) -end - -const resnet_config = Dict(18 => (([1, 1], [2, 2, 2, 2], [:A, :B, :B, :B]), basicblock), - 34 => (([1, 1], [3, 4, 6, 3], [:A, :B, :B, :B]), basicblock), - 50 => (([1, 1, 4], [3, 4, 6, 3], [:B, :B, :B, :B]), bottleneck), - 101 => (([1, 1, 4], [3, 4, 23, 3], [:B, :B, :B, :B]), bottleneck), - 152 => (([1, 1, 4], [3, 8, 36, 3], [:B, :B, :B, :B]), bottleneck)) - -""" - ResNet(channel_config, block_config, shortcut_config; - block, connection = addrelu, nclasses = 1000) - -Create a `ResNet` model -([reference](https://arxiv.org/abs/1512.03385v1)). -See also [`resnet`](#). - -# Arguments - - - `channel_config`: the growth rate of the output feature maps within a residual block - - `block_config`: a list of the number of residual blocks at each stage - - `shortcut_config`: the type of shortcut style (either `:A`, `:B`, or `:C`). - `shortcut_config` can also be a vector of symbols if different shortcut styles are applied to - different residual blocks. - - `block`: a function with input `(inplanes, outplanes, downsample=false)` that returns - a new residual block (see [`Metalhead.basicblock`](#) and [`Metalhead.bottleneck`](#)) - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `nclasses`: the number of output classes -""" -struct ResNet - layers::Any -end - -function ResNet(channel_config, block_config, shortcut_config; - block, connection = addrelu, nclasses = 1000) - layers = resnet(block, - shortcut_config, - connection; - channel_config = channel_config, - block_config = block_config, - nclasses = nclasses) - return ResNet(layers) -end - -@functor ResNet - -(m::ResNet)(x) = m.layers(x) - -backbone(m::ResNet) = m.layers[1] -classifier(m::ResNet) = m.layers[2] - -""" - ResNet(depth = 50; pretrain = false, nclasses = 1000) - -Create a ResNet model with a specified depth -([reference](https://arxiv.org/abs/1512.03385v1)) -following [these modification](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch) -referred as ResNet v1.5. - -See also [`Metalhead.resnet`](#). - -# Arguments - - - `depth`: depth of the ResNet model. Options include (18, 34, 50, 101, 152). - - `nclasses`: the number of output classes - -For `ResNet(18)` and `ResNet(34)`, the parameter-free shortcut style (type `:A`) -is used in the first block and the three other blocks use type `:B` connection -(following the implementation in PyTorch). The published version of -`ResNet(18)` and `ResNet(34)` used type `:A` shortcuts for all four blocks. The -example below shows how to create a 18 or 34-layer `ResNet` using only type `:A` -shortcuts: - -```julia -using Metalhead - -resnet18 = ResNet([1, 1], [2, 2, 2, 2], :A; block = Metalhead.basicblock) - -resnet34 = ResNet([1, 1], [3, 4, 6, 3], :A; block = Metalhead.basicblock) -``` - -The bottleneck of the orginal ResNet model has a stride of 2 on the first -convolutional layer when downsampling (instead of the second convolutional layers -as in ResNet v1.5). The architecture of the orignal ResNet model can be obtained -as shown below: - -```julia -resnet50_v1 = ResNet([1, 1, 4], [3, 4, 6, 3], :B; block = Metalhead.bottleneck_v1) -``` -""" -function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000) - @assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))" - config, block = resnet_config[depth] - model = ResNet(config...; block = block, nclasses = nclasses) - pretrain && loadpretrain!(model, string("resnet", depth)) - return model -end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl new file mode 100644 index 000000000..329663c13 --- /dev/null +++ b/src/convnets/resnets/core.jl @@ -0,0 +1,338 @@ +""" + basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, revnorm = false, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + +Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)). + +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `reduction_factor`: the factor by which the input feature maps + are reduced before the first convolution. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer + - `drop_path`: the drop path layer + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. +""" +function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, + reduction_factor::Integer = 1, activation = relu, + norm_layer = BatchNorm, revnorm::Bool = false, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + first_planes = planes ÷ reduction_factor + outplanes = planes + conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, + stride, pad = 1, bias = false) + conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, + pad = 1, bias = false) + layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), + drop_path] + return Chain(filter!(!=(identity), layers)...) +end + +""" + bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64, + reduction_factor = 1, activation = relu, + norm_layer = BatchNorm, revnorm = false, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + +Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)). + +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `cardinality`: the number of groups in the convolution. + - `base_width`: the number of output feature maps for each convolutional group. + - `reduction_factor`: the factor by which the input feature maps are reduced before the first + convolution. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `drop_block`: the drop block layer + - `drop_path`: the drop path layer + - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. +""" +function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, + cardinality::Integer = 1, base_width::Integer = 64, + reduction_factor::Integer = 1, activation = relu, + norm_layer = BatchNorm, revnorm::Bool = false, + drop_block = identity, drop_path = identity, + attn_fn = planes -> identity) + width = floor(Int, planes * (base_width / 64)) * cardinality + first_planes = width ÷ reduction_factor + outplanes = planes * 4 + conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, + bias = false) + conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, + stride, pad = 1, groups = cardinality, bias = false) + conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm, + bias = false) + layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., + attn_fn(outplanes), drop_path] + return Chain(filter!(!=(identity), layers)...) +end + +# Downsample layer using convolutions. +function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm, revnorm = false) + return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, + pad = SamePad(), stride, bias = false)...) +end + +# Downsample layer using max pooling +function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm, revnorm = false) + pool = (stride == 1) ? identity : MeanPool((2, 2); stride, pad = SamePad()) + return Chain(pool, + conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, + bias = false)...) +end + +# Downsample layer which is an identity projection. Uses max pooling +# when the output size is more than the input size. +# TODO - figure out how to make this work when outplanes < inplanes +function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) + if outplanes > inplanes + return Chain(MaxPool((1, 1); stride = 2), + y -> cat_channels(y, + zeros(eltype(y), + size(y, 1), + size(y, 2), + outplanes - inplanes, size(y, 4)))) + else + return identity + end +end + +# Shortcut configurations for the ResNet models +const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity), + :B => (downsample_conv, downsample_identity), + :C => (downsample_conv, downsample_conv), + :D => (downsample_pool, downsample_identity)) + +# Stride for each block in the ResNet model +function resnet_stride(stage_idx::Integer, block_idx::Integer) + return (stage_idx == 1 || block_idx != 1) ? 1 : 2 +end + +# returns `DropBlock`s for each stage of the ResNet as in timm. +# TODO - add experimental options for DropBlock as part of the API (#188) +# function _drop_blocks(drop_block_rate::AbstractFloat) +# return [ +# identity, identity, +# DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00), +# ] +# end + +""" + resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false, + norm_layer = BatchNorm, activation = relu) + +Builds a stem to be used in a ResNet model. See the `stem` argument of [`resnet`](#) for details +on how to use this function. + +# Arguments + + - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. + + + `:default`: Builds a stem based on the default ResNet stem, which consists of a single + 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling + layer with stride 2. + + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using + a deeper stem with 3 successive 3x3 convolutions having normalisation layers after each + one. This is followed by a 3x3 max pooling layer with stride 2. + + `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second + convolution. This is an experimental variant from the `timm` library in Python that + shows peformance improvements over the `:deep` stem in some cases. + + - `inchannels`: The number of channels in the input. + - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two. + - `norm_layer`: The normalisation layer used in the stem. + - `activation`: The activation function used in the stem. +""" +function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, + replace_pool::Bool = false, activation = relu, + norm_layer = BatchNorm, revnorm::Bool = false) + @assert stem_type in [:default, :deep, :deep_tiered] + "Stem type must be one of [:default, :deep, :deep_tiered]" + # Main stem + deep_stem = stem_type == :deep || stem_type == :deep_tiered + inplanes = deep_stem ? stem_width * 2 : 64 + # Deep stem that uses three successive 3x3 convolutions instead of a single 7x7 convolution + if deep_stem + if stem_type == :deep + stem_channels = (stem_width, stem_width) + elseif stem_type == :deep_tiered + stem_channels = (3 * (stem_width ÷ 4), stem_width) + end + conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; + norm_layer, revnorm, stride = 2, pad = 1, bias = false)..., + conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; + norm_layer, pad = 1, bias = false)..., + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) + else + conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) + end + bn1 = norm_layer(inplanes, activation) + # Stem pooling + stempool = replace_pool ? + Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, + revnorm, + stride = 2, pad = 1, bias = false)...) : + MaxPool((3, 3); stride = 2, pad = 1) + return Chain(conv1, bn1, stempool) +end + +function resnet_planes(block_repeats::Vector{<:Integer}) + return Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages) + for (stage_idx, stages) in enumerate(block_repeats)) +end + +function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, + reduction_factor::Integer = 1, expansion::Integer = 1, + norm_layer = BatchNorm, revnorm::Bool = false, + activation = relu, attn_fn = planes -> identity, + drop_block_rate = 0.0, drop_path_rate = 0.0, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_tuple[1] : downsample_tuple[2] + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = basicblock(inplanes, planes; stride, reduction_factor, activation, + norm_layer, revnorm, attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end + +function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, + cardinality::Integer = 1, base_width::Integer = 64, + reduction_factor::Integer = 1, expansion::Integer = 4, + norm_layer = BatchNorm, revnorm::Bool = false, + activation = relu, attn_fn = planes -> identity, + drop_block_rate = 0.0, drop_path_rate = 0.0, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) + blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # DropBlock, DropPath both take in rates based on a linear scaling schedule + # This is also needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + # `resnet_stride` is a callback that the user can tweak to change the stride of the + # blocks. It defaults to the standard behaviour as in the paper + stride = stride_fn(stage_idx, block_idx) + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_tuple[1] : downsample_tuple[2] + # DropBlock, DropPath both take in rates based on a linear scaling schedule + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + drop_path = DropPath(pathschedule[schedule_idx]) + drop_block = DropBlock(blockschedule[schedule_idx]) + block = bottleneck(inplanes, planes; stride, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_path, drop_block) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) + return block, downsample + end + return get_layers +end + +function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) + # Construct each stage + stages = [] + for (stage_idx, num_blocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...) + for block_idx in 1:num_blocks] + push!(stages, Chain(blocks...)) + end + return Chain(stages...) +end + +function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection, + classifier_fn) + # Build stages of the ResNet + stage_blocks = resnet_stages(get_layers, block_repeats, connection) + backbone = Chain(stem, stage_blocks) + # Build the classifier head + nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] + classifier = classifier_fn(nfeaturemaps) + return Chain(backbone, classifier) +end + +function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; + downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity), + cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, + reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), + inchannels::Integer = 3, stem_fn = resnet_stem, + connection = addact, activation = relu, norm_layer = BatchNorm, + revnorm::Bool = false, attn_fn = planes -> identity, + pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false, + drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0, + nclasses::Integer = 1000) + # Build stem + stem = stem_fn(; inchannels) + # Block builder + if block_type == :basicblock + @assert cardinality==1 "Cardinality must be 1 for `basicblock`" + @assert base_width==64 "Base width must be 64 for `basicblock`" + get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, revnorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt) + elseif block_type == :bottleneck + get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, + reduction_factor, activation, norm_layer, + revnorm, attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt) + else + # TODO: write better message when we have link to dev docs for resnet + throw(ArgumentError("Unknown block type $block_type")) + end + classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, + pool_layer, use_conv) + return resnet((imsize..., inchannels), stem, get_layers, block_repeats, + connection$activation, classifier_fn) +end +function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) + return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...) +end + +# block-layer configurations for ResNet-like models +const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]), + 34 => (:basicblock, [3, 4, 6, 3]), + 50 => (:bottleneck, [3, 4, 6, 3]), + 101 => (:bottleneck, [3, 4, 23, 3]), + 152 => (:bottleneck, [3, 8, 36, 3])) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl new file mode 100644 index 000000000..fac7e7415 --- /dev/null +++ b/src/convnets/resnets/resnet.jl @@ -0,0 +1,77 @@ +""" + ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a ResNet model with the specified depth. +((reference)[https://arxiv.org/abs/1512.03385]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `ResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct ResNet + layers::Any +end +@functor ResNet + +function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + _checkconfig(depth, keys(RESNET_CONFIGS)) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("ResNet", depth)) + end + return ResNet(layers) +end + +(m::ResNet)(x) = m.layers(x) + +backbone(m::ResNet) = m.layers[1] +classifier(m::ResNet) = m.layers[2] + +""" + WideResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a Wide ResNet model with the specified depth. The model is the same as ResNet +except for the bottleneck number of channels which is twice larger in every block. +The number of channels in outer 1x1 convolutions is the same. +((reference)[https://arxiv.org/abs/1605.07146]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the Wide ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: The number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `WideResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct WideResNet + layers::Any +end +@functor WideResNet + +function WideResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + _checkconfig(depth, [50, 101]) + layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("WideResNet", depth)) + end + return WideResNet(layers) +end + +(m::WideResNet)(x) = m.layers(x) + +backbone(m::WideResNet) = m.layers[1] +classifier(m::WideResNet) = m.layers[2] diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl new file mode 100644 index 000000000..8032df5ab --- /dev/null +++ b/src/convnets/resnets/resnext.jl @@ -0,0 +1,41 @@ +""" + ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) + +Creates a ResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/abs/1611.05431]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `ResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct ResNeXt + layers::Any +end +@functor ResNeXt + +(m::ResNeXt)(x) = m.layers(x) + +function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, + base_width = 4, inchannels = 3, nclasses = 1000) + _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end]) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) + if pretrain + loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) + end + return ResNeXt(layers) +end + +backbone(m::ResNeXt) = m.layers[1] +classifier(m::ResNeXt) = m.layers[2] diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl new file mode 100644 index 000000000..05d842173 --- /dev/null +++ b/src/convnets/resnets/seresnet.jl @@ -0,0 +1,81 @@ +""" + SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + +Creates a SEResNet model with the specified depth. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `inchannels`: the number of input channels. + - `nclasses`: the number of output classes + +!!! warning + + `SEResNet` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct SEResNet + layers::Any +end +@functor SEResNet + +(m::SEResNet)(x) = m.layers(x) + +function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) + _checkconfig(depth, keys(RESNET_CONFIGS)) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, + attn_fn = squeeze_excite) + if pretrain + loadpretrain!(layers, string("SEResNet", depth)) + end + return SEResNet(layers) +end + +backbone(m::SEResNet) = m.layers[1] +classifier(m::SEResNet) = m.layers[2] + +""" + SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) + +Creates a SEResNeXt model with the specified depth, cardinality, and base width. +((reference)[https://arxiv.org/pdf/1709.01507.pdf]) + +# Arguments + + - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. + - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet + - `cardinality`: the number of groups to be used in the 3x3 convolution in each block. + - `base_width`: the number of feature maps in each group. + - `inchannels`: the number of input channels + - `nclasses`: the number of output classes + +!!! warning + + `SEResNeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](#). +""" +struct SEResNeXt + layers::Any +end +@functor SEResNeXt + +(m::SEResNeXt)(x) = m.layers(x) + +function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, + inchannels = 3, nclasses = 1000) + _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end]) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, + attn_fn = squeeze_excite) + if pretrain + loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) + end + return SEResNeXt(layers) +end + +backbone(m::SEResNeXt) = m.layers[1] +classifier(m::SEResNeXt) = m.layers[2] diff --git a/src/convnets/resnext.jl b/src/convnets/resnext.jl deleted file mode 100644 index fc00bb180..000000000 --- a/src/convnets/resnext.jl +++ /dev/null @@ -1,126 +0,0 @@ -""" - resnextblock(inplanes, outplanes, cardinality, width, downsample = false) - -Create a basic residual block as defined in the paper for ResNeXt -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `downsample`: set to `true` to downsample the input -""" -function resnextblock(inplanes, outplanes, cardinality, width, downsample = false) - stride = downsample ? 2 : 1 - hidden_channels = cardinality * width - return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false)..., - conv_bn((3, 3), hidden_channels, hidden_channels; - stride = stride, pad = 1, bias = false, groups = cardinality)..., - conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false)...) -end - -""" - resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @. relu(x) + relu(y); - block_config, nclasses = 1000) - -Create a ResNeXt model -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments - - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `widen_factor`: the factor by which the width of the bottleneck is increased after each stage - - `connection`: the binary function applied to the output of residual and skip paths in a block - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -function resnext(cardinality, width, widen_factor = 2, - connection = (x, y) -> @. relu(x) + relu(y); - block_config, nclasses = 1000) - inplanes = 64 - baseplanes = 128 - layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3))) - push!(layers, MaxPool((3, 3); stride = (2, 2), pad = (1, 1))) - for (i, nrepeats) in enumerate(block_config) - # output planes within a block - outplanes = baseplanes * widen_factor - # push first skip connection on using first residual - # downsample the residual path if this is the first repetition of a block - push!(layers, - Parallel(connection, - resnextblock(inplanes, outplanes, cardinality, width, i != 1), - skip_projection(inplanes, outplanes, i != 1))) - # push remaining skip connections on using second residual - inplanes = outplanes - for _ in 2:nrepeats - push!(layers, - Parallel(connection, - resnextblock(inplanes, outplanes, cardinality, width, false), - skip_identity(inplanes, outplanes, false))) - end - baseplanes = outplanes - # double width after every cluster of blocks - width *= widen_factor - end - return Chain(Chain(layers), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(inplanes, nclasses))) -end - -""" - ResNeXt(cardinality, width; block_config, nclasses = 1000) - -Create a ResNeXt model -([reference](https://arxiv.org/abs/1611.05431)). - -# Arguments - - - `cardinality`: the number of groups to use for the convolution - - `width`: the number of feature maps in each group in the bottleneck - - `block_config`: a list of the number of residual blocks at each stage - - `nclasses`: the number of output classes -""" -struct ResNeXt - layers::Any -end - -function ResNeXt(cardinality, width; block_config, nclasses = 1000) - layers = resnext(cardinality, width; block_config, nclasses) - return ResNeXt(layers) -end - -@functor ResNeXt - -(m::ResNeXt)(x) = m.layers(x) - -backbone(m::ResNeXt) = m.layers[1] -classifier(m::ResNeXt) = m.layers[2] - -const resnext_config = Dict(50 => (3, 4, 6, 3), - 101 => (3, 4, 23, 3), - 152 => (3, 8, 36, 3)) - -""" - ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000) - -Create a ResNeXt model with specified configuration. Currently supported values for `config` are (50, 101). -([reference](https://arxiv.org/abs/1611.05431)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - -!!! warning - - `ResNeXt` does not currently support pretrained weights. - -See also [`Metalhead.resnext`](#). -""" -function ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, - nclasses = 1000) - @assert config in keys(resnext_config) "`config` must be one of $(sort(collect(keys(resnext_config))))" - model = ResNeXt(cardinality, width; block_config = resnext_config[config], nclasses) - pretrain && loadpretrain!(model, string("ResNeXt", config)) - return model -end diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index c4de36acc..abcdd63f8 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -15,11 +15,7 @@ function fire(inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes) branch_1 = Conv((1, 1), inplanes => squeeze_planes, relu) branch_2 = Conv((1, 1), squeeze_planes => expand1x1_planes, relu) branch_3 = Conv((3, 3), squeeze_planes => expand3x3_planes, relu; pad = 1) - - return Chain(branch_1, - Parallel(cat_channels, - branch_2, - branch_3)) + return Chain(branch_1, Parallel(cat_channels, branch_2, branch_3)) end """ @@ -29,24 +25,22 @@ Create a SqueezeNet ([reference](https://arxiv.org/abs/1602.07360v4)). """ function squeezenet() - layers = Chain(Chain(Conv((3, 3), 3 => 64, relu; stride = 2), - MaxPool((3, 3); stride = 2), - fire(64, 16, 64, 64), - fire(128, 16, 64, 64), - MaxPool((3, 3); stride = 2), - fire(128, 32, 128, 128), - fire(256, 32, 128, 128), - MaxPool((3, 3); stride = 2), - fire(256, 48, 192, 192), - fire(384, 48, 192, 192), - fire(384, 64, 256, 256), - fire(512, 64, 256, 256), - Dropout(0.5), - Conv((1, 1), 512 => 1000, relu)), - AdaptiveMeanPool((1, 1)), - MLUtils.flatten) - - return layers + return Chain(Chain(Conv((3, 3), 3 => 64, relu; stride = 2), + MaxPool((3, 3); stride = 2), + fire(64, 16, 64, 64), + fire(128, 16, 64, 64), + MaxPool((3, 3); stride = 2), + fire(128, 32, 128, 128), + fire(256, 32, 128, 128), + MaxPool((3, 3); stride = 2), + fire(256, 48, 192, 192), + fire(384, 48, 192, 192), + fire(384, 64, 256, 256), + fire(512, 64, 256, 256), + Dropout(0.5), + Conv((1, 1), 512 => 1000, relu)), + AdaptiveMeanPool((1, 1)), + MLUtils.flatten) end """ @@ -65,15 +59,16 @@ See also [`squeezenet`](#). struct SqueezeNet layers::Any end +@functor SqueezeNet function SqueezeNet(; pretrain = false) layers = squeezenet() - pretrain && loadpretrain!(layers, "SqueezeNet") + if pretrain + loadpretrain!(layers, "SqueezeNet") + end return SqueezeNet(layers) end -@functor SqueezeNet - (m::SqueezeNet)(x) = m.layers(x) backbone(m::SqueezeNet) = m.layers[1] diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 56975a124..ccfdd2cff 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -17,7 +17,7 @@ function vgg_block(ifilters, ofilters, depth, batchnorm) layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_bn(k, ifilters, ofilters; pad = p, bias = false)) + append!(layers, conv_norm(k, ifilters, ofilters; pad = p, bias = false)) else push!(layers, Conv(k, ifilters => ofilters, relu; pad = p)) end @@ -52,7 +52,7 @@ function vgg_convolutional_layers(config, batchnorm, inchannels) end """ - vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -63,19 +63,19 @@ Create VGG classifier (fully connected) layers the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `dropout`: the dropout level between each fully connected layer + - `dropout_rate`: the dropout level between each fully connected layer """ -function vgg_classifier_layers(imsize, nclasses, fcsize, dropout) +function vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) return Chain(MLUtils.flatten, Dense(Int(prod(imsize)), fcsize, relu), - Dropout(dropout), + Dropout(dropout_rate), Dense(fcsize, fcsize, relu), - Dropout(dropout), + Dropout(dropout_rate), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -90,31 +90,31 @@ Create a VGG model - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `dropout_rate`: dropout level between fully connected layers """ -function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) +function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) return Chain(Chain(conv), class) end -const vgg_conv_config = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], - :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], - :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], - :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) +const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], + :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], + :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], + :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) -const vgg_config = Dict(11 => :A, - 13 => :B, - 16 => :D, - 19 => :E) +const VGG_CONFIGS = Dict(11 => :A, + 13 => :B, + 16 => :D, + 19 => :E) struct VGG layers::Any end """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -126,17 +126,11 @@ Construct a VGG model with the specified input image size. Typically, the image - `nclasses`::Integer : number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `dropout_rate`: dropout level between fully connected layers """ -function VGG(imsize::Dims{2}; - config, inchannels, batchnorm = false, nclasses, fcsize, dropout) - layers = vgg(imsize; config = config, - inchannels = inchannels, - batchnorm = batchnorm, - nclasses = nclasses, - fcsize = fcsize, - dropout = dropout) - +function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, + dropout_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, dropout_rate) return VGG(layers) end @@ -159,13 +153,13 @@ See also [`VGG`](#). - `pretrain`: set to `true` to load pre-trained model weights for ImageNet """ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000) - @assert depth in keys(vgg_config) "depth must be from one in $(sort(collect(keys(vgg_config))))" - model = VGG((224, 224); config = vgg_conv_config[vgg_config[depth]], + _checkconfig(depth, keys(VGG_CONFIGS)) + model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]], inchannels = 3, batchnorm = batchnorm, nclasses = nclasses, fcsize = 4096, - dropout = 0.5) + dropout_rate = 0.5) if pretrain && !batchnorm loadpretrain!(model, string("vgg", depth)) elseif pretrain diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 1034136f3..04be476ff 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,26 +1,43 @@ module Layers using Flux -using Flux: outputsize, Zygote +using Flux: rng_from_array +using CUDA +using NNlib, NNlibCUDA using Functors +using ChainRulesCore using Statistics using MLUtils +using PartialFunctions +using Random include("../utilities.jl") include("attention.jl") +export MHAttention + +include("conv.jl") +export conv_norm, depthwise_sep_conv_norm, invertedresidual + +include("drop.jl") +export DropBlock, DropPath + include("embeddings.jl") +export PatchEmbedding, ViPosEmbedding, ClassTokens + include("mlp.jl") +export mlp_block, gated_mlp_block, create_fc, create_classifier + include("normalise.jl") -include("conv.jl") -include("others.jl") - -export MHAttention, - PatchEmbedding, ViPosEmbedding, ClassTokens, - mlp_block, gated_mlp_block, - LayerScale, DropPath, - ChannelLayerNorm, prenorm, - skip_identity, skip_projection, - conv_bn, depthwise_sep_conv_bn, - invertedresidual, squeeze_excite +export prenorm, ChannelLayerNorm + +include("pool.jl") +export AdaptiveMeanMaxPool + +include("scale.jl") +export LayerScale, inputscale + +include("selayers.jl") +export squeeze_excite, effective_squeeze_excite + end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index a1244a033..e2276aa01 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,14 +1,15 @@ """ - MHAttention(nheads::Integer, qkv_layer, attn_drop, projection) + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_rate = 0., proj_dropout_rate = 0.) Multi-head self-attention layer. -# Arguments: +# Arguments - - `nheads`: Number of heads - - `qkv_layer`: layer to be used for getting the query, key and value - - `attn_drop`: dropout rate after the self-attention layer - - `projection`: projection layer to be used after self-attention + - `planes`: number of input channels + - `nheads`: number of heads + - `qkv_bias`: whether to use bias in the layer to get the query, key and value + - `attn_dropout_rate`: dropout rate after the self-attention layer + - `proj_dropout_rate`: dropout rate after the projection layer """ struct MHAttention{P, Q, R} nheads::Int @@ -16,31 +17,17 @@ struct MHAttention{P, Q, R} attn_drop::Q projection::R end +@functor MHAttention -""" - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop = 0., proj_drop = 0.) - -Multi-head self-attention layer. - -# Arguments: - - - `planes`: number of input channels - - `nheads`: number of heads - - `qkv_bias`: whether to use bias in the layer to get the query, key and value - - `attn_drop`: dropout rate after the self-attention layer - - `proj_drop`: dropout rate after the projection layer -""" function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_drop = 0.0, proj_drop = 0.0) + attn_dropout_rate = 0.0, proj_dropout_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_drop) - proj = Chain(Dense(planes, planes), Dropout(proj_drop)) + attn_drop = Dropout(attn_dropout_rate) + proj = Chain(Dense(planes, planes), Dropout(proj_dropout_rate)) return MHAttention(nheads, qkv_layer, attn_drop, proj) end -@functor MHAttention - function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} nfeatures, seq_len, batch_size = size(x) x_reshaped = reshape(x, nfeatures, seq_len * batch_size) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 6363946d0..5610d3be2 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,40 +1,39 @@ """ - conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, stride = 1, pad = 0, dilation = 1, - groups = 1, [bias, weight, init], initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1) + conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; + norm_layer = BatchNorm, revnorm = false, preact = false, use_norm = true, + stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init]) Create a convolution + batch normalization pair with activation. # Arguments - - `kernelsize`: size of the convolution kernel (tuple) + - `kernel_size`: size of the convolution kernel (tuple) - `inplanes`: number of input feature maps - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - - `rev`: set to `true` to place the batch norm before the convolution + - `norm_layer`: the normalization layer used + - `revnorm`: set to `true` to place the batch norm before the convolution - `preact`: set to `true` to place the activation function before the batch norm - (only compatible with `rev = false`) - - `use_bn`: set to `false` to disable batch normalization - (only compatible with `rev = false` and `preact = false`) + (only compatible with `revnorm = false`) + - `use_norm`: set to `false` to disable normalization + (only compatible with `revnorm = false` and `preact = false`) - `stride`: stride of the convolution kernel - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) - - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) - - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) """ -function conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, preact = false, use_bn = true, - initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1, - kwargs...) - if !use_bn - (preact || rev) ? throw("preact only supported with `use_bn = true`") : - return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)] +function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; + norm_layer = BatchNorm, revnorm = false, preact = false, use_norm = true, + kwargs...) + if !use_norm + if (preact || revnorm) + throw(ArgumentError("`preact` only supported with `use_norm = true`")) + else + return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] + end end - layers = [] - if rev + if revnorm activations = (conv = activation, bn = identity) bnplanes = inplanes else @@ -42,127 +41,59 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; bnplanes = outplanes end if preact - rev ? throw(ArgumentError("preact and rev cannot be set at the same time")) : - activations = (conv = activation, bn = identity) + if revnorm + throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time")) + else + activations = (conv = activation, bn = identity) + end end - push!(layers, - Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...)) - push!(layers, - BatchNorm(Int(bnplanes), activations.bn; - initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum)) - return rev ? reverse(layers) : layers + layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...), + norm_layer(bnplanes, activations.bn)] + return revnorm ? reverse(layers) : layers +end + +function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; + kwargs...) + inplanes, outplanes = ch + return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) end """ - depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, use_bn = (true, true), - stride = 1, pad = 0, dilation = 1, [bias, weight, init], - initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1) + depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; + revnorm = false, use_norm = (true, true), + stride = 1, pad = 0, dilation = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: - - a `kernelsize` depthwise convolution from `inplanes => inplanes` - - a batch norm layer + `activation` (if `use_bn[1] == true`; otherwise `activation` is applied to the convolution output) - - a `kernelsize` convolution from `inplanes => outplanes` - - a batch norm layer + `activation` (if `use_bn[2] == true`; otherwise `activation` is applied to the convolution output) + - a `kernel_size` depthwise convolution from `inplanes => inplanes` + - a batch norm layer + `activation` (if `use_norm[1] == true`; otherwise `activation` is applied to the convolution output) + - a `kernel_size` convolution from `inplanes => outplanes` + - a batch norm layer + `activation` (if `use_norm[2] == true`; otherwise `activation` is applied to the convolution output) See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). # Arguments - - `kernelsize`: size of the convolution kernel (tuple) + - `kernel_size`: size of the convolution kernel (tuple) - `inplanes`: number of input feature maps - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - - `rev`: set to `true` to place the batch norm before the convolution - - `use_bn`: a tuple of two booleans to specify whether to use batch normalization for the first and second convolution + - `revnorm`: set to `true` to place the batch norm before the convolution + - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) - - `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) - - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) """ -function depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; - rev = false, use_bn = (true, true), - initβ = Flux.zeros32, initγ = Flux.ones32, - ϵ = 1.0f-5, momentum = 1.0f-1, - stride = 1, kwargs...) - return vcat(conv_bn(kernelsize, inplanes, inplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, - ϵ = ϵ, momentum = momentum, use_bn = use_bn[1], - stride = stride, groups = Int(inplanes), kwargs...), - conv_bn((1, 1), inplanes, outplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, use_bn = use_bn[2], - ϵ = ϵ, momentum = momentum)) -end - -""" - skip_projection(inplanes, outplanes, downsample = false) - -Create a skip projection -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps - - `downsample`: set to `true` to downsample the input -""" -function skip_projection(inplanes, outplanes, downsample = false) - return downsample ? - Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)) : - Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)) -end - -# array -> PaddedView(0, array, outplanes) for zero padding arrays -""" - skip_identity(inplanes, outplanes[, downsample]) - -Create a identity projection -([reference](https://arxiv.org/abs/1512.03385v1)). - -# Arguments: - - - `inplanes`: the number of input feature maps - - `outplanes`: the number of output feature maps - - `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#). -""" -function skip_identity(inplanes, outplanes) - if outplanes > inplanes - return Chain(MaxPool((1, 1); stride = 2), - y -> cat_channels(y, - zeros(eltype(y), - size(y, 1), - size(y, 2), - outplanes - inplanes, size(y, 4)))) - else - return identity - end -end -skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes) - -""" - squeeze_excite(channels, reduction = 4) - -Squeeze and excitation layer used by MobileNet variants -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `channels`: the number of input/output feature maps - - `reduction = 4`: the reduction factor for the number of hidden feature maps - (must be ≥ 1) -""" -function squeeze_excite(channels, reduction = 4) - @assert (reduction>=1) "`reduction` must be >= 1" - return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - conv_bn((1, 1), channels, channels ÷ reduction, relu; - bias = false)..., - conv_bn((1, 1), channels ÷ reduction, channels, hardσ)...), - .*) +function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; + norm_layer = BatchNorm, revnorm = false, + use_norm = (true, true), stride = 1, kwargs...) + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; + norm_layer, revnorm, use_norm = use_norm[1], stride, + groups = inplanes, kwargs...), + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, + use_norm = use_norm[2])) end """ @@ -174,28 +105,29 @@ Create a basic inverted residual block for MobileNet variants # Arguments - - `kernel_size`: The kernel size of the convolutional layers - - `inplanes`: The number of input feature maps + - `kernel_size`: kernel size of the convolutional layers + - `inplanes`: number of input feature maps - `hidden_planes`: The number of feature maps in the hidden layer - `outplanes`: The number of output feature maps - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)). - Must be ≥ 1 or `nothing` for no squeeze and excite layer. """ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation = relu; stride, reduction = nothing) @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 conv1 = (inplanes == hidden_planes) ? identity : - Chain(conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false)) - selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes, reduction) + Chain(conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false)) + selayer = isnothing(reduction) ? identity : + squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ, + norm_layer = BatchNorm) invres = Chain(conv1, - conv_bn(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad = pad, groups = hidden_planes)..., + conv_norm(kernel_size, hidden_planes, hidden_planes, activation; + bias = false, stride, pad = pad, groups = hidden_planes)..., selayer, - conv_bn((1, 1), hidden_planes, outplanes, identity; bias = false)...) + conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...) return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres end diff --git a/src/layers/drop.jl b/src/layers/drop.jl new file mode 100644 index 000000000..b4a882cff --- /dev/null +++ b/src/layers/drop.jl @@ -0,0 +1,147 @@ +# Generates the mask to be used for `DropBlock` +@inline function _dropblock_mask(rng, x, gamma, clipped_block_size) + block_mask = rand_like(rng, x) + block_mask .= block_mask .< gamma + return 1 .- maxpool(block_mask, (clipped_block_size, clipped_block_size); + stride = 1, pad = clipped_block_size ÷ 2) +end +ChainRulesCore.@non_differentiable _dropblock_mask(rng, x, gamma, clipped_block_size) + +""" + dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale, active::Bool = true) + +The dropblock function. If `active` is `true`, for each input, it zeroes out continguous +regions of size `block_size` in the input. Otherwise, it simply returns the input `x`. + +# Arguments + + - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only + supported on the CPU. + - `x`: input array + - `drop_block_prob`: probability of dropping a block + - `block_size`: size of the block to drop + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, + refer to [the paper](https://arxiv.org/abs/1810.12890). + +If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. +""" +# TODO add experimental `DropBlock` options from timm such as gaussian noise and +# more precise `DropBlock` to deal with edges (#188) +function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size, + gamma_scale) where {T} + H, W, _, _ = size(x) + total_size = H * W + clipped_block_size = min(block_size, min(H, W)) + gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 / + ((W - block_size + 1) * (H - block_size + 1)) + block_mask = dropblock_mask(rng, x, gamma, clipped_block_size) + normalize_scale = length(block_mask) / sum(block_mask) .+ T(1e-6) + return x .* block_mask .* normalize_scale +end + +## bs is `clipped_block_size` +# Dispatch for GPU +dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) +function dropblock_mask(rng, x::CuArray, gamma, bs) + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports CUDA.RNG for CuArrays.")) +end +# Dispatch for CPU +dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) + +mutable struct DropBlock{F, R <: AbstractRNG} + drop_block_prob::F + block_size::Integer + gamma_scale::F + active::Union{Bool, Nothing} + rng::R +end +@functor DropBlock +trainable(a::DropBlock) = (;) + +function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_scale) + @assert 0 ≤ drop_block_prob ≤ 1 + "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0 ≤ gamma_scale ≤ 1 + return "gamma_scale must be between 0 and 1, got $gamma_scale" +end +function _dropblock_checks(x, drop_block_prob, gamma_scale) + throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock.")) +end +ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_scale) + +function (m::DropBlock)(x) + _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) + if Flux._isactive(m) + return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) + else + return x + end +end + +function Flux.testmode!(m::DropBlock, mode = true) + return (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +end + +""" + DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, + rng = rng_from_array()) + +The `DropBlock` layer. While training, it zeroes out continguous regions of +size `block_size` in the input. During inference, it simply returns the input `x`. +((reference)[https://arxiv.org/abs/1810.12890]) + +# Arguments + + - `drop_block_prob`: probability of dropping a block + - `block_size`: size of the block to drop + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, + refer to [the paper](https://arxiv.org/abs/1810.12890). + - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only + supported on the CPU. +""" +function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, + rng = rng_from_array()) + if drop_block_prob == 0.0 + return identity + end + return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) +end + +function Base.show(io::IO, d::DropBlock) + print(io, "DropBlock(", d.drop_block_prob) + print(io, ", block_size = $(repr(d.block_size))") + print(io, ", gamma_scale = $(repr(d.gamma_scale))") + return print(io, ")") +end + +""" + DropPath(p; [rng = rng_from_array(x)]) + +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and +`identity` otherwise. +([reference](https://arxiv.org/abs/1603.09382)) + +This layer can be used to drop certain blocks in a residual structure and allow them to +propagate completely through the skip connection. It can be used in two ways: either with +all blocks having the same survival probability or with a linear scaling rule across the +blocks. This is performed only at training time. At test time, the `DropPath` layer is +equivalent to `identity`. + +!!! warning + + In the case of the linear scaling rule, the calculations of survival probabilities for each + block may lead to a survival probability > 1 for a given block. This will lead to + `DropPath` returning `identity`, which may not be desirable. This usually happens with + a low number of blocks and a high base survival probability, so it is recommended to + use a fixed base survival probability across blocks. If this is not possible, then + a lower base survival probability is recommended. + +# Arguments + + - `p`: rate of Stochastic Depth. + - `rng`: can be used to pass in a custom RNG instead of the default. See `Flux.Dropout` + for more information on the behaviour of this argument. Custom RNGs are only supported + on the CPU. +""" +DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index ad079db9d..3e85f18d9 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -8,10 +8,10 @@ _flatten_spatial(x) = permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1 Patch embedding layer used by many vision transformer-like models to split the input image into patches. -# Arguments: +# Arguments - `imsize`: the size of the input image - - `inchannels`: the number of channels in the input. The default value is 3. + - `inchannels`: the number of channels in the input. - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding - `norm_layer`: the normalization layer - by default the identity function but otherwise takes a diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 25ead874b..a3bdb0fb5 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,6 +1,6 @@ """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0., activation = gelu) + dropout_rate = 0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -9,18 +9,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models. - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `dropout_rate`: Dropout rate. - `activation`: Activation function to use. """ function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout), - Dense(hidden_planes, outplanes), Dropout(dropout)) + dropout_rate = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout_rate), + Dense(hidden_planes, outplanes), Dropout(dropout_rate)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0., activation = gelu) + outplanes::Integer = inplanes; dropout_rate = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -31,16 +31,50 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `dropout_rate`: Dropout rate. - `activation`: Activation function to use. """ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0.0, activation = gelu) + outplanes::Integer = inplanes; dropout_rate = 0.0, + activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(dropout), + Dropout(dropout_rate), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(dropout)) + Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) + +""" + create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = 0.0, use_conv = false) + +Creates a classifier head to be used for models. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. +""" +function create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = 0.0, use_conv = false) + # Pooling + if pool_layer === identity + @assert use_conv + "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used" + end + flatten_in_pool = !use_conv && pool_layer !== identity + if use_conv + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv` is true" + end + global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer + # Fully-connected layer + fc = use_conv ? Conv((1, 1), inplanes => nclasses) : Dense(inplanes => nclasses) + return Chain(global_pool, Dropout(dropout_rate), fc) +end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 4f69dab03..bb83f042d 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -2,7 +2,7 @@ prenorm(planes, fn) = Chain(LayerNorm(planes), fn) """ - ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1f-5) + ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) A variant of LayerNorm where the input is normalised along the channel dimension. The input is expected to have channel dimension with size @@ -16,12 +16,11 @@ struct ChannelLayerNorm{D, T} diag::D ϵ::T end - @functor ChannelLayerNorm -(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) - -function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5) +function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6) diag = Flux.Scale(1, 1, sz, λ) return ChannelLayerNorm(diag, ϵ) end + +(m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)) diff --git a/src/layers/pool.jl b/src/layers/pool.jl new file mode 100644 index 000000000..1962ab0fb --- /dev/null +++ b/src/layers/pool.jl @@ -0,0 +1,16 @@ +""" + AdaptiveMeanMaxPool(output_size = (1, 1); connection = +) + +A type of adaptive pooling layer which uses both mean and max pooling and combines them to +produce a single output. Note that this is equivalent to +`Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))` + +# Arguments + + - `output_size`: The size of the output after pooling. + - `connection`: The connection type to use. +""" +function AdaptiveMeanMaxPool(connection, output_size = (1, 1)) + return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) +end +AdaptiveMeanMaxPool(output_size::Tuple = (1, 1)) = AdaptiveMeanMaxPool(+, output_size) diff --git a/src/layers/others.jl b/src/layers/scale.jl similarity index 55% rename from src/layers/others.jl rename to src/layers/scale.jl index 770bccebd..965b50f38 100644 --- a/src/layers/others.jl +++ b/src/layers/scale.jl @@ -1,3 +1,13 @@ +""" + inputscale(λ; activation = identity) + +Scale the input by a scalar `λ` and applies an activation function to it. +Equivalent to `activation.(λ .* x)`. +""" +inputscale(λ; activation = identity) = _input_scale$(λ, activation) +_input_scale(λ, activation, x) = activation.(λ .* x) +_input_scale(λ, ::typeof(identity), x) = λ .* x + """ LayerScale(λ, planes::Integer) @@ -12,15 +22,3 @@ Creates a `Flux.Scale` layer that performs "`LayerScale`" function LayerScale(planes::Integer, λ) return λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity end - -""" - DropPath(p) - -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. -([reference](https://arxiv.org/abs/1603.09382)) - -# Arguments - - - `p`: rate of Stochastic Depth. -""" -DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl new file mode 100644 index 000000000..db0f3715d --- /dev/null +++ b/src/layers/selayers.jl @@ -0,0 +1,47 @@ +""" + squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, + activation = relu, gate_activation = sigmoid, norm_layer = identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0.0)) + +Creates a squeeze-and-excitation layer used in MobileNets and SE-Nets. + +# Arguments + + - `inplanes`: The number of input feature maps + - `reduction`: The reduction factor for the number of hidden feature maps + - `rd_divisor`: The divisor for the number of hidden feature maps. + - `activation`: The activation function for the first convolution layer + - `gate_activation`: The activation function for the gate layer + - `norm_layer`: The normalization layer to be used after the convolution layers + - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer +""" +function squeeze_excite(inplanes; reduction = 16, rd_divisor = 8, + activation = relu, gate_activation = sigmoid, + norm_layer = planes -> identity, + rd_planes = _round_channels(inplanes ÷ reduction, rd_divisor, 0)) + layers = [AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes => rd_planes), + norm_layer(rd_planes), + activation, + Conv((1, 1), rd_planes => inplanes), + norm_layer(inplanes), + gate_activation] + return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) +end + +""" + effective_squeeze_excite(inplanes, gate_activation = sigmoid) + +Effective squeeze-and-excitation layer. +(reference: [CenterMask : Real-Time Anchor-Free Instance Segmentation](https://arxiv.org/abs/1911.06667)) + +# Arguments + + - `inplanes`: The number of input feature maps + - `gate_activation`: The activation function for the gate layer +""" +function effective_squeeze_excite(inplanes; gate_activation = sigmoid, kwargs...) + return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), + Conv((1, 1), inplanes, inplanes), + gate_activation), .*) +end diff --git a/src/mixers/core.jl b/src/mixers/core.jl new file mode 100644 index 000000000..9f9d3b305 --- /dev/null +++ b/src/mixers/core.jl @@ -0,0 +1,43 @@ +""" + mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., + depth = 12, nclasses = 1000, kwargs...) + +Creates a model with the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)). + +# Arguments + + - `block`: the type of mixer block to use in the model - architecture dependent + (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) + - `imsize`: the size of the input image + - `inchannels`: the number of input channels + - `norm_layer`: the normalization layer to use in the model + - `patch_size`: the size of the patches + - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) + - `drop_path_rate`: Stochastic depth rate + - `depth`: the number of blocks in the model + - `nclasses`: number of output classes + - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if + not specified. +""" +function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, + norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), + embedplanes = 512, drop_path_rate = 0.0, + depth = 12, nclasses = 1000, kwargs...) + npatches = prod(imsize .÷ patch_size) + dp_rates = linear_scheduler(drop_path_rate; depth) + layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), + Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], + kwargs...) + for i in 1:depth])) + classification_head = Chain(norm_layer(embedplanes), seconddimmean, + Dense(embedplanes, nclasses)) + return Chain(layers, classification_head) +end + +# Configurations for MLPMixer models +const MIXER_CONFIGS = Dict(:small => Dict(:depth => 8, :planes => 512), + :base => Dict(:depth => 12, :planes => 768), + :large => Dict(:depth => 24, :planes => 1024), + :huge => Dict(:depth => 32, :planes => 1280)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl new file mode 100644 index 000000000..9ebd2dce3 --- /dev/null +++ b/src/mixers/gmlp.jl @@ -0,0 +1,110 @@ +""" + SpatialGatingUnit(norm, proj) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `norm`: the normalisation layer to use + - `proj`: the projection layer to use +""" +struct SpatialGatingUnit{T, F} + norm::T + proj::F +end +@functor SpatialGatingUnit + +""" + SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `norm_layer`: the normalisation layer to use +""" +function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) + gateplanes = planes ÷ 2 + norm = norm_layer(gateplanes) + proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) + return SpatialGatingUnit(norm, proj) +end + +function (m::SpatialGatingUnit)(x) + u, v = chunk(x, 2; dims = 1) + v = m.norm(v) + v = m.proj(permutedims(v, (2, 1, 3))) + return u .* permutedims(v, (2, 1, 3)) +end + +""" + spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, + norm_layer = LayerNorm, dropout_rate = 0.0, drop_path_rate = 0.0, + activation = gelu) + +Creates a feedforward block based on the gMLP model architecture described in the paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block + - `norm_layer`: the normalisation layer to use + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks +""" +function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, + mlp_layer = gated_mlp_block, dropout_rate = 0.0, + drop_path_rate = 0.0, + activation = gelu) + channelplanes = Int(mlp_ratio * planes) + sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) + return SkipConnection(Chain(norm_layer(planes), + mlp_layer(sgu, planes, channelplanes; activation, + dropout_rate), + DropPath(drop_path_rate)), +) +end + +struct gMLP + layers::Any +end +@functor gMLP + +""" + gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) + +Creates a model with the gMLP architecture. +([reference](https://arxiv.org/abs/2105.08050)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] + layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, + patch_size, embedplanes, drop_path_rate, depth, nclasses) + return gMLP(layers) +end + +(m::gMLP)(x) = m.layers(x) + +backbone(m::gMLP) = m.layers[1] +classifier(m::gMLP) = m.layers[2] diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl new file mode 100644 index 000000000..7b6d4aa09 --- /dev/null +++ b/src/mixers/mlpmixer.jl @@ -0,0 +1,69 @@ +""" + mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout_rate = 0., drop_path_rate = 0., activation = gelu) + +Creates a feedforward block for the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)) + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP + and/or the channel mixing MLP as a ratio to the number of planes in the block. + - `mlp_layer`: the MLP layer to use in the block + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks +""" +function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] + return Chain(SkipConnection(Chain(LayerNorm(planes), + swapdims((2, 1, 3)), + mlp_layer(npatches, tokenplanes; activation, + dropout_rate), + swapdims((2, 1, 3)), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(LayerNorm(planes), + mlp_layer(planes, channelplanes; activation, + dropout_rate), + DropPath(drop_path_rate)), +)) +end + +struct MLPMixer + layers::Any +end +@functor MLPMixer + +""" + MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) + +Creates a model with the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] + layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, + nclasses) + return MLPMixer(layers) +end + +(m::MLPMixer)(x) = m.layers(x) + +backbone(m::MLPMixer) = m.layers[1] +classifier(m::MLPMixer) = m.layers[2] diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl new file mode 100644 index 000000000..17e340310 --- /dev/null +++ b/src/mixers/resmlp.jl @@ -0,0 +1,72 @@ +""" + resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, + activation = gelu, λ = 1e-4) + +Creates a block for the ResMixer architecture. +([reference](https://arxiv.org/abs/2105.03404)). + +# Arguments + + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block + - `mlp_layer`: the MLP block to use + - `dropout_rate`: the dropout rate to use in the MLP blocks + - `drop_path_rate`: Stochastic depth rate + - `activation`: the activation function to use in the MLP blocks + - `λ`: initialisation constant for the LayerScale +""" +function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, + dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu, + λ = 1e-4) + return Chain(SkipConnection(Chain(Flux.Scale(planes), + swapdims((2, 1, 3)), + Dense(npatches, npatches), + swapdims((2, 1, 3)), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(Flux.Scale(planes), + mlp_layer(planes, Int(mlp_ratio * planes); + dropout_rate, + activation), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +)) +end + +struct ResMLP + layers::Any +end +@functor ResMLP + +""" + ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + +Creates a model with the ResMLP architecture. +([reference](https://arxiv.org/abs/2105.03404)). + +# Arguments + + - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `patch_size`: the size of the patches + - `imsize`: the size of the input image + - `drop_path_rate`: Stochastic depth rate + - `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). +""" +function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] + layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, + drop_path_rate, depth, nclasses) + return ResMLP(layers) +end + +(m::ResMLP)(x) = m.layers(x) + +backbone(m::ResMLP) = m.layers[1] +classifier(m::ResMLP) = m.layers[2] diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl deleted file mode 100644 index 942abc823..000000000 --- a/src/other/mlpmixer.jl +++ /dev/null @@ -1,297 +0,0 @@ -""" - mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0., drop_path_rate = 0., activation = gelu) - -Creates a feedforward block for the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)) - -# Arguments: - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP - and/or the channel mixing MLP as a ratio to the number of planes in the block. - - `mlp_layer`: the MLP layer to use in the block - - `dropout`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks -""" -function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0.0, drop_path_rate = 0.0, activation = gelu) - tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] - return Chain(SkipConnection(Chain(LayerNorm(planes), - swapdims((2, 1, 3)), - mlp_layer(npatches, tokenplanes; activation, dropout), - swapdims((2, 1, 3)), - DropPath(drop_path_rate)), +), - SkipConnection(Chain(LayerNorm(planes), - mlp_layer(planes, channelplanes; activation, dropout), - DropPath(drop_path_rate)), +)) -end - -""" - mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., - depth = 12, nclasses = 1000, kwargs...) - -Creates a model with the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)). - -# Arguments - - - `block`: the type of mixer block to use in the model - architecture dependent - (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) - - `imsize`: the size of the input image - - `inchannels`: the number of input channels - - `norm_layer`: the normalization layer to use in the model - - `patch_size`: the size of the patches - - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) - - `drop_path_rate`: Stochastic depth rate - - `depth`: the number of blocks in the model - - `nclasses`: number of output classes - - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if - not specified. -""" -function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, - norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0.0, - depth = 12, nclasses = 1000, kwargs...) - npatches = prod(imsize .÷ patch_size) - dp_rates = LinRange{Float32}(0.0, drop_path_rate, depth) - layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), - Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], - kwargs...) - for i in 1:depth])) - - classification_head = Chain(norm_layer(embedplanes), seconddimmean, - Dense(embedplanes, nclasses)) - return Chain(layers, classification_head) -end - -# Configurations for MLPMixer models -mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), - :base => Dict(:depth => 12, :planes => 768), - :large => Dict(:depth => 24, :planes => 1024), - :huge => Dict(:depth => 32, :planes => 1280)) - -struct MLPMixer - layers::Any -end - -""" - MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) - -Creates a model with the MLPMixer architecture. -([reference](https://arxiv.org/pdf/2105.01601)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, - nclasses) - return MLPMixer(layers) -end - -@functor MLPMixer - -(m::MLPMixer)(x) = m.layers(x) - -backbone(m::MLPMixer) = m.layers[1] -classifier(m::MLPMixer) = m.layers[2] - -""" - resmixerblock(planes, npatches; dropout = 0., drop_path_rate = 0., mlp_ratio = 4.0, - activation = gelu, λ = 1e-4) - -Creates a block for the ResMixer architecture. -([reference](https://arxiv.org/abs/2105.03404)). - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number - of planes in the block - - `mlp_layer`: the MLP block to use - - `dropout`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks - - `λ`: initialisation constant for the LayerScale -""" -function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, - dropout = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4) - return Chain(SkipConnection(Chain(Flux.Scale(planes), - swapdims((2, 1, 3)), - Dense(npatches, npatches), - swapdims((2, 1, 3)), - LayerScale(planes, λ), - DropPath(drop_path_rate)), +), - SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); dropout, - activation), - LayerScale(planes, λ), - DropPath(drop_path_rate)), +)) -end - -struct ResMLP - layers::Any -end - -""" - ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), - drop_path_rate = 0., nclasses = 1000) - -Creates a model with the ResMLP architecture. -([reference](https://arxiv.org/abs/2105.03404)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, - drop_path_rate, depth, nclasses) - return ResMLP(layers) -end - -@functor ResMLP - -(m::ResMLP)(x) = m.layers(x) - -backbone(m::ResMLP) = m.layers[1] -classifier(m::ResMLP) = m.layers[2] - -""" - SpatialGatingUnit(norm, proj) - -Creates a spatial gating unit as described in the gMLP paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `norm`: the normalisation layer to use - - `proj`: the projection layer to use -""" -struct SpatialGatingUnit{T, F} - norm::T - proj::F -end - -""" - SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - -Creates a spatial gating unit as described in the gMLP paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `norm_layer`: the normalisation layer to use -""" -function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - gateplanes = planes ÷ 2 - norm = norm_layer(gateplanes) - proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) - return SpatialGatingUnit(norm, proj) -end - -@functor SpatialGatingUnit - -function (m::SpatialGatingUnit)(x) - u, v = chunk(x, 2; dims = 1) - v = m.norm(v) - v = m.proj(permutedims(v, (2, 1, 3))) - return u .* permutedims(v, (2, 1, 3)) -end - -""" - spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, - norm_layer = LayerNorm, dropout = 0.0, drop_path_rate = 0., - activation = gelu) - -Creates a feedforward block based on the gMLP model architecture described in the paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number - of planes in the block - - `norm_layer`: the normalisation layer to use - - `dropout`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate - - `activation`: the activation function to use in the MLP blocks -""" -function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, - mlp_layer = gated_mlp_block, dropout = 0.0, - drop_path_rate = 0.0, - activation = gelu) - channelplanes = Int(mlp_ratio * planes) - sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) - return SkipConnection(Chain(norm_layer(planes), - mlp_layer(sgu, planes, channelplanes; activation, dropout), - DropPath(drop_path_rate)), +) -end - -struct gMLP - layers::Any -end - -""" - gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) - -Creates a model with the gMLP architecture. -([reference](https://arxiv.org/abs/2105.08050)). - -# Arguments - - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` - - `patch_size`: the size of the patches - - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate - - `nclasses`: number of output classes - -See also [`Metalhead.mlpmixer`](#). -""" -function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), - imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] - layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, - patch_size, embedplanes, drop_path_rate, depth, nclasses) - return gMLP(layers) -end - -@functor gMLP - -(m::gMLP)(x) = m.layers(x) - -backbone(m::gMLP) = m.layers[1] -classifier(m::gMLP) = m.layers[2] diff --git a/src/utilities.jl b/src/utilities.jl index 0c4f46796..981777228 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -10,22 +10,26 @@ function _round_channels(channels, divisor, min_value = divisor) end """ - addrelu(x, y) + addact(activation = relu, xs...) + +Convenience function for applying an activation function to the output after +summing up the input arrays. Useful as the `connection` argument for the block +function in [`resnet`](#). -Convenience function for `(x, y) -> @. relu(x + y)`. -Useful as the `connection` argument for [`resnet`](#). See also [`reluadd`](#). """ -addrelu(x, y) = @. relu(x + y) +addact(activation = relu, xs...) = activation(sum(xs)) """ - reluadd(x, y) + actadd(activation = relu, xs...) + +Convenience function for adding input arrays after applying an activation +function to them. Useful as the `connection` argument for the block function in +[`resnet`](#). -Convenience function for `(x, y) -> @. relu(x) + relu(y)`. -Useful as the `connection` argument for [`resnet`](#). See also [`addrelu`](#). """ -reluadd(x, y) = @. relu(x) + relu(y) +actadd(activation = relu, xs...) = sum(activation.(x) for x in xs) """ cat_channels(x, y, zs...) @@ -36,16 +40,6 @@ Convenient reduction operator for use with `Parallel`. """ cat_channels(xy...) = cat(xy...; dims = Val(3)) -""" - inputscale(λ; activation = identity) - -Scale the input by a scalar `λ` and applies an activation function to it. -Equivalent to `activation.(λ .* x)`. -""" -inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation) -_input_scale(x, λ, activation) = activation.(λ .* x) -_input_scale(x, λ, ::typeof(identity)) = λ .* x - """ swapdims(perm) @@ -67,3 +61,18 @@ function _maybe_big_show(io, model) show(io, model) end end + +""" + linear_scheduler(drop_path_rate = 0.0; start_value = 0.0, depth) + +Returns the dropout rates for a given depth using the linear scaling rule. +""" +function linear_scheduler(drop_rate = 0.0; depth, start_value = 0.0) + return LinRange(start_value, drop_rate, depth) +end + +# Utility function for depth and configuration checks in models +function _checkconfig(config, configs) + @assert config in configs + return "Invalid configuration. Must be one of $(sort(collect(configs)))." +end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 012bfef9d..1fece2191 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ -transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.) +transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). @@ -10,23 +10,25 @@ Transformer as used in the base ViT architecture. - `depth`: number of attention blocks - `nheads`: number of attention heads - `mlp_ratio`: ratio of MLP layers to the number of input channels - - `dropout`: dropout rate + - `dropout_rate`: dropout rate """ -function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.0) +function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.0) layers = [Chain(SkipConnection(prenorm(planes, - MHAttention(planes, nheads; attn_drop = dropout, - proj_drop = dropout)), +), + MHAttention(planes, nheads; + attn_dropout_rate = dropout_rate, + proj_dropout_rate = dropout_rate)), + +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); - dropout)), +)) + dropout_rate)), +)) for _ in 1:depth] return Chain(layers) end """ vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, - emb_dropout = 0.1, pool = :class, nclasses = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, + emb_dropout_rate = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -40,35 +42,36 @@ Creates a Vision Transformer (ViT) model. - `depth`: number of blocks in the transformer - `nheads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer - - `dropout`: dropout rate + - `dropout_rate`: dropout rate - `emb_dropout`: dropout rate for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, - emb_dropout = 0.1, pool = :class, nclasses = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, + emb_dropout_rate = 0.1, pool = :class, nclasses = 1000) @assert pool in [:class, :mean] - "Pool type must be either :class (class token) or :mean (mean pooling)" + "Pool type must be either `:class` (class token) or `:mean` (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), ViPosEmbedding(embedplanes, npatches + 1), - Dropout(emb_dropout), - transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout), + Dropout(emb_dropout_rate), + transformer_encoder(embedplanes, depth, nheads; mlp_ratio, + dropout_rate), (pool == :class) ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end -vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), - :small => (depth = 12, embedplanes = 384, nheads = 6), - :base => (depth = 12, embedplanes = 768, nheads = 12), - :large => (depth = 24, embedplanes = 1024, nheads = 16), - :huge => (depth = 32, embedplanes = 1280, nheads = 16), - :giant => (depth = 40, embedplanes = 1408, nheads = 16, - mlp_ratio = 48 // 11), - :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, - mlp_ratio = 64 // 13)) +const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), + :small => (depth = 12, embedplanes = 384, nheads = 6), + :base => (depth = 12, embedplanes = 768, nheads = 12), + :large => (depth = 24, embedplanes = 1024, nheads = 16), + :huge => (depth = 32, embedplanes = 1280, nheads = 16), + :giant => (depth = 40, embedplanes = 1408, nheads = 16, + mlp_ratio = 48 // 11), + :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, + mlp_ratio = 64 // 13)) """ ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3, @@ -92,13 +95,13 @@ See also [`Metalhead.vit`](#). struct ViT layers::Any end +@functor ViT function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) - @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))" - kwargs = vit_configs[mode] + _checkconfig(mode, keys(VIT_CONFIGS)) + kwargs = VIT_CONFIGS[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) - return ViT(layers) end @@ -106,5 +109,3 @@ end backbone(m::ViT) = m.layers[1] classifier(m::ViT) = m.layers[2] - -@functor ViT diff --git a/test/convnets.jl b/test/convnets.jl index 97cfd846e..e62b14299 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -3,11 +3,9 @@ @test size(model(x_256)) == (1000, 1) @test_throws ArgumentError AlexNet(pretrain = true) @test gradtest(model, x_256) + _gc() end -GC.safepoint() -GC.gc() - @testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] m = VGG(sz, batchnorm = bn) @@ -18,62 +16,116 @@ GC.gc() @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "ResNet" begin + # Tests for pretrained ResNets + ## TODO: find a way to port pretrained models to the new ResNet API @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] m = ResNet(sz) - @test size(m(x_256)) == (1000, 1) - if (ResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz, pretrain = true)) - else - @test_throws ArgumentError ResNet(sz, pretrain = true) + @test size(m(x_224)) == (1000, 1) + # if (ResNet, sz) in PRETRAINED_MODELS + # @test acctest(ResNet(sz, pretrain = true)) + # else + # @test_throws ArgumentError ResNet(sz, pretrain = true) + # end + end + + @testset "resnet" begin + @testset for block_fn in [:basicblock, :bottleneck] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3] + ] + @testset for layers in layer_list + drop_list = [ + (dropout_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), + (dropout_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), + (dropout_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), + ] + @testset for drop_rates in drop_list + m = Metalhead.resnet(block_fn, layers; drop_rates...) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end + end end - @test gradtest(m, x_256) - GC.safepoint() - GC.gc() end - @testset "Shortcut C" begin - m = Metalhead.resnet(Metalhead.basicblock, :C; - channel_config = [1, 1], - block_config = [2, 2, 2, 2]) - @test size(m(x_256)) == (1000, 1) - @test gradtest(m, x_256) + @testset "WideResNet" begin + @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(ResNet(sz, pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) + end + end end end -GC.safepoint() -GC.gc() - @testset "ResNeXt" begin @testset for depth in [50, 101, 152] - m = ResNeXt(depth) + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + end +end + +@testset "SEResNet" begin + @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) @test size(m(x_224)) == (1000, 1) - if ResNeXt in PRETRAINED_MODELS - @test acctest(ResNeXt(depth, pretrain = true)) + if (SEResNet, depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth, pretrain = true)) else - @test_throws ArgumentError ResNeXt(depth, pretrain = true) + @test_throws ArgumentError SEResNet(depth, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() +@testset "SEResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + end +end @testset "EfficientNet" begin - @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4] #, :b5, :b6, :b7, :b8] + @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] # preferred image resolution scaling - r = Metalhead.efficientnet_global_configs[name][1] + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1] x = rand(Float32, r, r, 3, 1) m = EfficientNet(name) @test size(m(x)) == (1000, 1) @@ -83,14 +135,10 @@ GC.gc() @test_throws ArgumentError EfficientNet(name, pretrain = true) end @test gradtest(m, x) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(x_224)) == (1000, 1) @@ -100,11 +148,9 @@ GC.gc() @test_throws ArgumentError GoogLeNet(pretrain = true) end @test gradtest(m, x_224) + _gc() end -GC.safepoint() -GC.gc() - @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @testset "Inceptionv3" begin @@ -117,8 +163,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "Inceptionv4" begin m = Inceptionv4() @test size(m(x_299)) == (1000, 2) @@ -129,8 +174,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "InceptionResNetv2" begin m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) @@ -141,8 +185,7 @@ GC.gc() end @test gradtest(m, x_299) end - GC.safepoint() - GC.gc() + _gc() @testset "Xception" begin m = Xception() @test size(m(x_299)) == (1000, 2) @@ -153,11 +196,9 @@ GC.gc() end @test gradtest(m, x_299) end + _gc() end -GC.safepoint() -GC.gc() - @testset "SqueezeNet" begin m = SqueezeNet() @test size(m(x_224)) == (1000, 1) @@ -167,15 +208,12 @@ GC.gc() @test_throws ArgumentError SqueezeNet(pretrain = true) end @test gradtest(m, x_224) + _gc() end -GC.safepoint() -GC.gc() - @testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) if (DenseNet, sz) in PRETRAINED_MODELS @test acctest(DenseNet(sz, pretrain = true)) @@ -183,18 +221,13 @@ GC.gc() @test_throws ArgumentError DenseNet(sz, pretrain = true) end @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end -GC.safepoint() -GC.gc() - @testset "MobileNet" verbose = true begin @testset "MobileNetv1" begin m = MobileNetv1() - @test size(m(x_224)) == (1000, 1) if MobileNetv1 in PRETRAINED_MODELS @test acctest(MobileNetv1(pretrain = true)) @@ -203,8 +236,7 @@ GC.gc() end @test gradtest(m, x_224) end - GC.safepoint() - GC.gc() + _gc() @testset "MobileNetv2" begin m = MobileNetv2() @test size(m(x_224)) == (1000, 1) @@ -215,12 +247,10 @@ GC.gc() end @test gradtest(m, x_224) end - GC.safepoint() - GC.gc() + _gc() @testset "MobileNetv3" verbose = true begin @testset for mode in [:small, :large] m = MobileNetv3(mode) - @test size(m(x_224)) == (1000, 1) if (MobileNetv3, mode) in PRETRAINED_MODELS @test acctest(MobileNetv3(mode; pretrain = true)) @@ -228,12 +258,11 @@ GC.gc() @test_throws ArgumentError MobileNetv3(mode; pretrain = true) end @test gradtest(m, x_224) + _gc() end end end -GC.safepoint() -GC.gc() @testset "ConvNeXt" verbose = true begin @testset for mode in [:small, :base, :large, :tiny, :xlarge] @@ -241,22 +270,16 @@ GC.gc() m = ConvNeXt(mode; drop_path_rate) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end end -GC.safepoint() -GC.gc() - @testset "ConvMixer" verbose = true begin @testset for mode in [:small, :base, :large] m = ConvMixer(mode) - @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) - GC.safepoint() - GC.gc() + _gc() end end diff --git a/test/mixers.jl b/test/mixers.jl new file mode 100644 index 000000000..885ff5838 --- /dev/null +++ b/test/mixers.jl @@ -0,0 +1,32 @@ +@testset "MLPMixer" begin + @testset for mode in [:small, :base, :large] #:huge] + @testset for drop_path_rate in [0.0, 0.5] + m = MLPMixer(mode; drop_path_rate) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end + end +end + +@testset "ResMLP" begin + @testset for mode in [:small, :base, :large] #:huge] + @testset for drop_path_rate in [0.0, 0.5] + m = ResMLP(mode; drop_path_rate) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end + end +end + +@testset "gMLP" begin + @testset for mode in [:small, :base, :large] #:huge] + @testset for drop_path_rate in [0.0, 0.5] + m = gMLP(mode; drop_path_rate) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end + end +end diff --git a/test/other.jl b/test/other.jl deleted file mode 100644 index 3c1752f3a..000000000 --- a/test/other.jl +++ /dev/null @@ -1,35 +0,0 @@ -@testset "MLPMixer" begin - @testset for mode in [:small, :base] # :large, # :huge] - @testset for drop_path_rate in [0.0, 0.5] - m = MLPMixer(mode; drop_path_rate) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - GC.safepoint() - GC.gc() - end - end -end - -@testset "ResMLP" begin - @testset for mode in [:small, :base] # :large, # :huge] - @testset for drop_path_rate in [0.0, 0.5] - m = ResMLP(mode; drop_path_rate) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - GC.safepoint() - GC.gc() - end - end -end - -@testset "gMLP" begin - @testset for mode in [:small, :base] # :large, # :huge] - @testset for drop_path_rate in [0.0, 0.5] - m = gMLP(mode; drop_path_rate) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - GC.safepoint() - GC.gc() - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index f1a9787b9..622bfc394 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,11 @@ const PRETRAINED_MODELS = [ (ResNet, 152), ] +function _gc() + GC.safepoint() + GC.gc(true) +end + function gradtest(model, input) y, pb = Zygote.pullback(() -> model(input), Flux.params(model)) gs = pb(ones(Float32, size(y))) @@ -24,8 +29,8 @@ function gradtest(model, input) end function normalize_imagenet(data) - cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1)) - cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1)) + cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1)) + cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1)) return (data .- cmean) ./ cstd end @@ -33,7 +38,7 @@ end const TEST_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") const TEST_IMG = imresize(Images.load(TEST_PATH), (224, 224)) # CHW -> WHC -const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,1)) |> normalize_imagenet +const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> normalize_imagenet # image net labels const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) @@ -53,18 +58,12 @@ x_256 = rand(Float32, 256, 256, 3, 1) include("convnets.jl") end -GC.safepoint() -GC.gc() - -# Other tests -@testset verbose = true "Other" begin - include("other.jl") +# Mixer tests +@testset verbose = true "Mixers" begin + include("mixers.jl") end -GC.safepoint() -GC.gc() - # ViT tests @testset verbose = true "ViTs" begin - include("vit-based.jl") + include("vits.jl") end diff --git a/test/vit-based.jl b/test/vits.jl similarity index 52% rename from test/vit-based.jl rename to test/vits.jl index 9dc348819..13733ddec 100644 --- a/test/vit-based.jl +++ b/test/vits.jl @@ -1,9 +1,8 @@ @testset "ViT" begin - for mode in [:small, :base, :large] # :tiny, #,:huge, :giant, :gigantic] + for mode in [:tiny, :small, :base, :large, :huge] #:giant, #:gigantic m = ViT(mode) @test size(m(x_256)) == (1000, 1) @test gradtest(m, x_256) - GC.safepoint() - GC.gc() + _gc() end end