diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index a1bf822b9..5e14d6c49 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -31,6 +31,7 @@ jobs:
           - '["EfficientNet"]'
           - 'r"/*/ResNet*"'
           - '[r"ResNeXt", r"SEResNet"]'
+          - '[r"Res2Net", r"Res2NeXt"]'
           - '"Inception"'
           - '"DenseNet"'
           - '["ConvNeXt", "ConvMixer"]'
diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index 78073c154..aa236454c 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -26,6 +26,7 @@ include("convnets/resnets/core.jl")
 include("convnets/resnets/resnet.jl")
 include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
+include("convnets/resnets/res2net.jl")
 ## Inceptions
 include("convnets/inception/googlenet.jl")
 include("convnets/inception/inceptionv3.jl")
@@ -57,7 +58,7 @@ include("pretrain.jl")
 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-       WideResNet, ResNeXt, SEResNet, SEResNeXt,
+       WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
@@ -65,8 +66,8 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
-          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2,
+          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
diff --git a/src/convnets/inception/googlenet.jl b/src/convnets/inception/googlenet.jl
index 11d4dd7d3..54f814479 100644
--- a/src/convnets/inception/googlenet.jl
+++ b/src/convnets/inception/googlenet.jl
@@ -36,7 +36,7 @@ Create an Inception-v1 model (commonly referred to as GoogLeNet)
   - `nclasses`: the number of output classes
 """
-function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
+function googlenet(; dropout_rate = 0.4, inchannels::Integer = 3, nclasses::Integer = 1000)
     backbone = Chain(Conv((7, 7), inchannels => 64; stride = 2, pad = 3),
                      MaxPool((3, 3); stride = 2, pad = 1),
                      Conv((1, 1), 64 => 64),
@@ -53,7 +53,7 @@ function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
                      MaxPool((3, 3); stride = 2, pad = 1),
                      _inceptionblock(832, 256, 160, 320, 32, 128, 128),
                      _inceptionblock(832, 384, 192, 384, 48, 128, 128))
-    return Chain(backbone, create_classifier(1024, nclasses; dropout_rate = 0.4))
+    return Chain(backbone, create_classifier(1024, nclasses; dropout_rate))
 end
 
 """
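The `googlenet` hunks above are part of a broader cleanup in this PR: classifier dropout rates that were hard-coded inside the builders become `dropout_rate` keywords, with the old values kept as defaults. A minimal sketch of what this enables, assuming the builder is reached through the unexported `Metalhead.googlenet` (the exported `GoogLeNet` wrapper is untouched by this diff):

    using Metalhead

    # Default behaviour is unchanged: the classifier keeps its 0.4 dropout rate.
    m_default = Metalhead.googlenet()

    # The rate can now be overridden at construction time instead of being fixed.
    m_light = Metalhead.googlenet(; dropout_rate = 0.1, nclasses = 10)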
diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl
index c2855191b..a8bfbaefa 100644
--- a/src/convnets/inception/inceptionresnetv2.jl
+++ b/src/convnets/inception/inceptionresnetv2.jl
@@ -96,7 +96,8 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0,
 end
 
 """
-    InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3,
+                      nclasses::Integer = 1000)
 
 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))
@@ -118,9 +119,9 @@
 end
 @functor InceptionResNetv2
 
 function InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3,
-                           dropout_rate = 0.0, nclasses::Integer = 1000)
-    layers = inceptionresnetv2(; inchannels, dropout_rate, nclasses)
+                           nclasses::Integer = 1000)
+    layers = inceptionresnetv2(; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, "InceptionResNetv2")
     end
diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl
index 83a37f649..bc5ec3a2b 100644
--- a/src/convnets/inception/inceptionv3.jl
+++ b/src/convnets/inception/inceptionv3.jl
@@ -133,7 +133,7 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
   - `nclasses`: the number of output classes
 """
-function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
+function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000)
     backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)...,
                      conv_norm((3, 3), 32, 32)...,
                      conv_norm((3, 3), 32, 64; pad = 1)...,
@@ -152,7 +152,7 @@ function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
                      inceptionv3_d(768),
                      inceptionv3_e(1280),
                      inceptionv3_e(2048))
-    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate = 0.2))
+    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
 end
 
 """
diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl
index 1c97daddc..4964f3ca1 100644
--- a/src/convnets/inception/xception.jl
+++ b/src/convnets/inception/xception.jl
@@ -66,8 +66,7 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
                      xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
                      depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
                      depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
-    classifier = create_classifier(2048, nclasses; dropout_rate)
-    return Chain(backbone, classifier)
+    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
 end
 
 """
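With this change, all three Inception-family builders end with the same one-liner, `Chain(backbone, create_classifier(nfeatures, nclasses; dropout_rate))`. For orientation: `create_classifier` builds the usual pooled head. The hand-written equivalent below is a sketch only; judging by the `resnet` keywords later in this diff, the real helper also accepts options such as `pool_layer` and `use_conv`, and the exact composition here is an assumption rather than Metalhead's implementation:

    using Flux

    # Roughly the head produced by create_classifier(inplanes, nclasses; dropout_rate):
    # global average pooling, flattening, dropout, and a Dense projection.
    rough_classifier(inplanes, nclasses; dropout_rate = 0.0) =
        Chain(AdaptiveMeanPool((1, 1)), Flux.flatten,
              Dropout(dropout_rate), Dense(inplanes => nclasses))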
diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl
index 7b7c268e2..699edcbe8 100644
--- a/src/convnets/resnets/core.jl
+++ b/src/convnets/resnets/core.jl
@@ -1,8 +1,9 @@
 """
-    basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -11,10 +12,11 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 # Arguments
 
   - `inplanes`: number of input feature maps
   - `planes`: number of feature maps for the block
   - `stride`: the stride of the block
-  - `reduction_factor`: the factor by which the input feature maps
-    are reduced before the first convolution.
+  - `reduction_factor`: the factor by which the input feature maps are reduced before
+    the first convolution.
   - `activation`: the activation function to use.
   - `norm_layer`: the normalization layer to use.
+  - `revnorm`: set to `true` to place the normalisation layer before the convolution
   - `drop_block`: the drop block layer
   - `drop_path`: the drop path layer
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -36,11 +38,12 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
 end
 
 """
-    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-               reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+               cardinality::Integer = 1, base_width::Integer = 64,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
 
@@ -55,6 +58,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
     convolution.
   - `activation`: the activation function to use.
   - `norm_layer`: the normalization layer to use.
+  - `revnorm`: set to `true` to place the normalisation layer before the convolution
   - `drop_block`: the drop block layer
   - `drop_path`: the drop path layer
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
@@ -112,7 +116,7 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
 end
 
 # Shortcut configurations for the ResNet models
-const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
+const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity),
                            :B => (downsample_conv, downsample_identity),
                            :C => (downsample_conv, downsample_conv),
                            :D => (downsample_pool, downsample_identity))
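Renaming `shortcut_dict` to `RESNET_SHORTCUTS` matches the `SCREAMING_SNAKE_CASE` used for the other constants introduced in this PR. The keys follow the shortcut variants of the ResNet paper, and the `Symbol` overload of `resnet` further down resolves them through this table. A hypothetical call using the names defined in this file (all unexported, hence the `Metalhead.` qualifiers):

    # :B (the default) uses a projection shortcut only where the shape changes;
    # :C uses a projection shortcut in every block.
    downsample_pair = Metalhead.RESNET_SHORTCUTS[:C]  # (downsample_conv, downsample_conv)

    # Equivalent to Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3], :C; nclasses = 10)
    layers = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3], downsample_pair;
                              nclasses = 10)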
@@ -153,7 +157,8 @@ on how to use this function.
     shows performance improvements over the `:deep` stem in some cases.
 
   - `inchannels`: The number of channels in the input.
-  - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+  - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution +
+    normalization with a stride of two.
   - `norm_layer`: The normalisation layer used in the stem.
   - `activation`: The activation function used in the stem.
 """
@@ -253,8 +258,6 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
                         downsample_tuple[1] : downsample_tuple[2]
-        # DropBlock, DropPath both take in rates based on a linear scaling schedule
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
@@ -289,20 +292,19 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
     return Chain(backbone, classifier_fn(nfeaturemaps))
 end
 
-function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
-                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+function resnet(block_type, block_repeats::AbstractVector{<:Integer},
+                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
                 cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
                 reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
-                inchannels::Integer = 3, stem_fn = resnet_stem,
-                connection = addact, activation = relu, norm_layer = BatchNorm,
-                revnorm::Bool = false, attn_fn = planes -> identity,
-                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
-                nclasses::Integer = 1000)
+                inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
+                activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
+                attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
+                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
+                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
-    if block_type == :basicblock
+    if block_type == basicblock
         @assert cardinality==1 "Cardinality must be 1 for `basicblock`"
         @assert base_width==64 "Base width must be 64 for `basicblock`"
         get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
@@ -310,14 +312,26 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
                                         drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
-    elseif block_type == :bottleneck
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type == bottleneck
         get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
-                                        reduction_factor, activation, norm_layer,
-                                        revnorm, attn_fn, drop_block_rate, drop_path_rate,
+                                        reduction_factor, activation, norm_layer, revnorm,
+                                        attn_fn, drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type == bottle2neck
+        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`"
+        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`"
+        @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`"
+        get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
+                                         activation, norm_layer, revnorm, attn_fn,
+                                         stride_fn = resnet_stride,
+                                         planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt,
+                                         kwargs...)
     else
         # TODO: write better message when we have link to dev docs for resnet
        throw(ArgumentError("Unknown block type $block_type"))
@@ -328,12 +342,16 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
                   connection$activation, classifier_fn)
 end
 
 function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
-    return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
+    return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)
 end
 
 # block-layer configurations for ResNet-like models
-const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
-                            34 => (:basicblock, [3, 4, 6, 3]),
-                            50 => (:bottleneck, [3, 4, 6, 3]),
-                            101 => (:bottleneck, [3, 4, 23, 3]),
-                            152 => (:bottleneck, [3, 8, 36, 3]))
+const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]),
+                            34 => (basicblock, [3, 4, 6, 3]),
+                            50 => (bottleneck, [3, 4, 6, 3]),
+                            101 => (bottleneck, [3, 4, 23, 3]),
+                            152 => (bottleneck, [3, 8, 36, 3]))
+
+const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]),
+                             101 => (bottleneck, [3, 4, 23, 3]),
+                             152 => (bottleneck, [3, 8, 36, 3]))
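Two things change in the dispatch above: block types are passed as functions (`basicblock`, `bottleneck`, `bottle2neck`) rather than `Symbol`s, and the trailing `kwargs...` lets block-specific options flow through to the matching builder; this is how `scale` reaches `bottle2neck_builder` without `resnet` having to know about it. `LRESNET_CONFIGS` collects the bottleneck-only depths (50/101/152) shared by the larger ResNet variants. A sketch of how the Res2Net constructor defined below combines the two:

    # Depth 50 resolves to (bottleneck, [3, 4, 6, 3]); Res2Net keeps only the layer
    # counts, swaps in bottle2neck, and lets the trailing kwargs... carry scale through.
    block, repeats = Metalhead.LRESNET_CONFIGS[50]
    layers = Metalhead.resnet(Metalhead.bottle2neck, repeats;
                              base_width = 26, scale = 4, nclasses = 1000)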
diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl
new file mode 100644
index 000000000..8e054da82
--- /dev/null
+++ b/src/convnets/resnets/res2net.jl
@@ -0,0 +1,159 @@
+"""
+    bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                cardinality::Integer = 1, base_width::Integer = 26,
+                scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
+                revnorm::Bool = false, attn_fn = planes -> identity)
+
+Creates a bottleneck block as described in the Res2Net paper.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+
+  - `inplanes`: number of input feature maps
+  - `planes`: number of feature maps for the block
+  - `stride`: the stride of the block
+  - `cardinality`: the number of groups in the 3x3 convolutions.
+  - `base_width`: the number of output feature maps for each convolutional group.
+  - `scale`: the number of feature groups in the block. See the [paper](https://arxiv.org/abs/1904.01169)
+    for more details.
+  - `activation`: the activation function to use.
+  - `norm_layer`: the normalization layer to use.
+  - `revnorm`: set to `true` to place the normalisation layer before the convolution
+  - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
+"""
+function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                     cardinality::Integer = 1, base_width::Integer = 26,
+                     scale::Integer = 4, activation = relu, is_first::Bool = false,
+                     norm_layer = BatchNorm, revnorm::Bool = false,
+                     attn_fn = planes -> identity)
+    width = fld(planes * base_width, 64) * cardinality
+    outplanes = planes * 4
+    pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity
+    conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride,
+                                pad = 1, groups = cardinality, bias = false)...)
+                for _ in 1:max(1, scale - 1)]
+    reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
+               Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...)))
+    tuplify = if is_first
+        x -> tuple(x...)
+    else
+        x -> tuple(x[1], tuple(x[2:end]...))
+    end
+    layers = [conv_norm((1, 1), inplanes => width * scale, activation;
+                        norm_layer, revnorm, bias = false)...,
+              chunk$(; size = width, dims = 3), tuplify, reslayer,
+              conv_norm((1, 1), width * scale => outplanes, activation;
+                        norm_layer, revnorm, bias = false)...,
+              attn_fn(outplanes)]
+    return Chain(filter(!=(identity), layers)...)
+end
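The width arithmetic in `bottle2neck` mirrors the reference implementation: the 1x1 entry convolution expands the input to `scale` chunks of `fld(planes * base_width, 64) * cardinality` channels each, `chunk` splits them along the channel dimension, and the block projects back out to `planes * 4` channels. A worked example for the default Res2Net-50 26w×4s configuration (the numbers follow directly from the formulas above):

    planes, base_width, cardinality, scale = 64, 26, 1, 4

    width = fld(planes * base_width, 64) * cardinality  # fld(1664, 64) * 1 = 26
    inner = width * scale                               # 104 channels, split into 4 groups of 26
    outplanes = planes * 4                              # 256 output channels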
+
+function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
+                             inplanes::Integer = 64, cardinality::Integer = 1,
+                             base_width::Integer = 26, scale::Integer = 4,
+                             expansion::Integer = 4, norm_layer = BatchNorm,
+                             revnorm::Bool = false, activation = relu,
+                             attn_fn = planes -> identity,
+                             stride_fn = resnet_stride, planes_fn = resnet_planes,
+                             downsample_tuple = (downsample_conv, downsample_identity))
+    planes_vec = collect(planes_fn(block_repeats))
+    # closure over `idxs`
+    function get_layers(stage_idx::Integer, block_idx::Integer)
+        # This is needed for block `inplanes` and `planes` calculations
+        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+        planes = planes_vec[schedule_idx]
+        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
+        # `resnet_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper
+        stride = stride_fn(stage_idx, block_idx)
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_tuple[1] : downsample_tuple[2]
+        is_first = (stride > 1 || downsample_fn != downsample_tuple[2]) ? true : false
+        block = bottle2neck(inplanes, planes; stride, cardinality, base_width, scale,
+                            activation, is_first, norm_layer, revnorm, attn_fn)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
+        return block, downsample
+    end
+    return get_layers
+end
+
+"""
+    Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+            base_width::Integer = 26, inchannels::Integer = 3,
+            nclasses::Integer = 1000)
+
+Creates a Res2Net model with the specified depth, scale, and base width.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+
+  - `depth`: one of `[50, 101, 152]`. The depth of the Res2Net model.
+  - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+  - `scale`: the number of feature groups in the block. See the
+    [paper](https://arxiv.org/abs/1904.01169) for more details.
+  - `base_width`: the number of feature maps in each group.
+  - `inchannels`: the number of input channels.
+  - `nclasses`: the number of output classes
+"""
+struct Res2Net
+    layers::Any
+end
+@functor Res2Net
+
+function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+                 base_width::Integer = 26, inchannels::Integer = 3,
+                 nclasses::Integer = 1000)
+    _checkconfig(depth, keys(LRESNET_CONFIGS))
+    layers = resnet(bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
+                    inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers, string("Res2Net", depth, "_", base_width, "x", scale))
+    end
+    return Res2Net(layers)
+end
+
+(m::Res2Net)(x) = m.layers(x)
+
+backbone(m::Res2Net) = m.layers[1]
+classifier(m::Res2Net) = m.layers[2]
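Typical construction, mirroring the tests added at the bottom of this PR (a sketch: only depths present in `LRESNET_CONFIGS` are accepted, and `pretrain = true` works only for configurations with published weights):

    using Metalhead

    m = Res2Net(50; base_width = 26, scale = 4)  # Res2Net-50 26w×4s
    size(m(rand(Float32, 224, 224, 3, 1)))       # (1000, 1) for an ImageNet-style input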
+
+"""
+    Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+             base_width::Integer = 4, cardinality::Integer = 8,
+             inchannels::Integer = 3, nclasses::Integer = 1000)
+
+Creates a Res2NeXt model with the specified depth, scale, base width and cardinality.
+([reference](https://arxiv.org/abs/1904.01169))
+
+# Arguments
+
+  - `depth`: one of `[50, 101, 152]`. The depth of the Res2NeXt model.
+  - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
+  - `scale`: the number of feature groups in the block. See the
+    [paper](https://arxiv.org/abs/1904.01169) for more details.
+  - `base_width`: the number of feature maps in each group.
+  - `cardinality`: the number of groups in the 3x3 convolutions.
+  - `inchannels`: the number of input channels.
+  - `nclasses`: the number of output classes
+"""
+struct Res2NeXt
+    layers::Any
+end
+@functor Res2NeXt
+
+function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
+                  base_width::Integer = 4, cardinality::Integer = 8,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(depth, keys(LRESNET_CONFIGS))
+    layers = resnet(bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
+                    cardinality, inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers,
+                      string("Res2NeXt", depth, "_", base_width, "x", cardinality,
+                             "x", scale))
+    end
+    return Res2NeXt(layers)
+end
+
+(m::Res2NeXt)(x) = m.layers(x)
+
+backbone(m::Res2NeXt) = m.layers[1]
+classifier(m::Res2NeXt) = m.layers[2]
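Analogous usage for Res2NeXt, whose constructor defaults give the 4w×8c×4s configuration; `backbone` and `classifier` are qualified here since this sketch does not assume they are exported:

    using Metalhead

    m = Res2NeXt(50)         # scale = 4, base_width = 4, cardinality = 8
    Metalhead.backbone(m)    # the convolutional trunk, m.layers[1]
    Metalhead.classifier(m)  # the pooled classification head, m.layers[2]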
diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl
index 8bbbe0a80..cdccddd4b 100644
--- a/src/convnets/resnets/resnet.jl
+++ b/src/convnets/resnets/resnet.jl
@@ -57,8 +57,8 @@ end
 function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
                     nclasses::Integer = 1000)
-    _checkconfig(depth, [50, 101])
-    layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
+    _checkconfig(depth, keys(LRESNET_CONFIGS))
+    layers = resnet(LRESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("WideResNet", depth))
     end
diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl
index e3556ac95..8c43d2f62 100644
--- a/src/convnets/resnets/resnext.jl
+++ b/src/convnets/resnets/resnext.jl
@@ -29,8 +29,8 @@ end
 function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
                  base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
-    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
-    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
+    _checkconfig(depth, keys(LRESNET_CONFIGS))
+    layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
     if pretrain
         loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d"))
     end
diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl
index b020861bc..da074e57d 100644
--- a/src/convnets/resnets/seresnet.jl
+++ b/src/convnets/resnets/seresnet.jl
@@ -69,8 +69,8 @@ end
 function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
                    base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
-    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
-    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
+    _checkconfig(depth, keys(LRESNET_CONFIGS))
+    layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
                     attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width))
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index c355eac2f..de214bcbc 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -16,7 +16,7 @@ Create a convolution + batch normalization pair with activation.
   - `outplanes`: number of output feature maps
   - `activation`: the activation function for the final layer
   - `norm_layer`: the normalization layer used
-  - `revnorm`: set to `true` to place the batch norm before the convolution
+  - `revnorm`: set to `true` to place the normalisation layer before the convolution
   - `preact`: set to `true` to place the activation function before the batch norm
     (only compatible with `revnorm = false`)
   - `use_norm`: set to `false` to disable normalization
diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl
index b784a8f8e..37cc271fb 100644
--- a/src/mixers/mlpmixer.jl
+++ b/src/mixers/mlpmixer.jl
@@ -56,7 +56,8 @@ struct MLPMixer
 end
 @functor MLPMixer
 
-function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16),
+function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224),
+                  patch_size::Dims{2} = (16, 16),
                   inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(config, keys(MIXER_CONFIGS))
     layers = mlpmixer(mixerblock, imsize; patch_size, MIXER_CONFIGS[config]..., inchannels,
diff --git a/src/utilities.jl b/src/utilities.jl
index 833f87ce0..f5737831c 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -38,7 +38,10 @@ Concatenate `x` and `y` (and any `z`s) along the channel dimension (third dimension).
 Equivalent to `cat(x, y, zs...; dims=3)`.
 Convenient reduction operator for use with `Parallel`.
 """
-cat_channels(xy...) = cat(xy...; dims = Val(3))
+cat_channels(xs::AbstractArray...) = cat(xs...; dims = Val(3))
+cat_channels(x::AbstractArray, y::Tuple) = cat_channels(x, y...)
+cat_channels(x::Tuple, y::AbstractArray...) = cat_channels(x..., y...)
+cat_channels(x::Tuple) = cat_channels(x...)
 
 """
     swapdims(perm)
diff --git a/test/convnets.jl b/test/convnets.jl
index 1f9a0ca98..6d7dab496 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -33,7 +33,7 @@ end
 end
 
 @testset "resnet" begin
-    @testset for block_fn in [:basicblock, :bottleneck]
+    @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck]
         layer_list = [
             [2, 2, 2, 2],
             [3, 4, 6, 3],
@@ -121,6 +121,45 @@ end
     end
 end
 
+@testset "Res2Net" begin
+    @testset for (base_width, scale) in [(26, 4), (48, 2), (14, 8), (26, 6), (26, 8)]
+        m = Res2Net(50; base_width, scale)
+        @test size(m(x_224)) == (1000, 1)
+        if (Res2Net, 50, base_width, scale) in PRETRAINED_MODELS
+            @test acctest(Res2Net(50; base_width, scale, pretrain = true))
+        else
+            @test_throws ArgumentError Res2Net(50; base_width, scale, pretrain = true)
+        end
+        @test gradtest(m, x_224)
+        _gc()
+    end
+    @testset for (base_width, scale) in [(26, 4)]
+        m = Res2Net(101; base_width, scale)
+        @test size(m(x_224)) == (1000, 1)
+        if (Res2Net, 101, base_width, scale) in PRETRAINED_MODELS
+            @test acctest(Res2Net(101; base_width, scale, pretrain = true))
+        else
+            @test_throws ArgumentError Res2Net(101; base_width, scale, pretrain = true)
+        end
+        @test gradtest(m, x_224)
+        _gc()
+    end
+end
+
+@testset "Res2NeXt" begin
+    @testset for depth in [50, 101]
+        m = Res2NeXt(depth)
+        @test size(m(x_224)) == (1000, 1)
+        if (Res2NeXt, depth) in PRETRAINED_MODELS
+            @test acctest(Res2NeXt(depth; pretrain = true))
+        else
+            @test_throws ArgumentError Res2NeXt(depth; pretrain = true)
+        end
+        @test gradtest(m, x_224)
+        _gc()
+    end
+end
+
 @testset "EfficientNet" begin
     @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] # preferred image resolution scaling
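A closing note on the `cat_channels` methods added in src/utilities.jl above: they exist because `bottle2neck` feeds the outputs of `Parallel` and `PairwiseFusion` (tuples, possibly nested) into the channel-wise concatenation. A small illustration of the new dispatch, with array shapes chosen arbitrarily:

    using Metalhead

    x = rand(Float32, 8, 8, 16, 1)
    y = rand(Float32, 8, 8, 8, 1)

    Metalhead.cat_channels(x, y)       # plain arrays: 16 + 8 = 24 channels
    Metalhead.cat_channels(x, (y, y))  # the tuple is splatted first: 32 channels
    Metalhead.cat_channels((x, y, y))  # a lone tuple is splatted the same way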