
Commit b49fc40

Some more cleanup
1 parent e048d78 commit b49fc40

5 files changed: +61 -73 lines changed

src/convnets/inceptions/xception.jl

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int
         end
         push!(layers, relu)
         append!(layers,
-                dwsep_conv_norm((3, 3), inc, outc; pad = 1, use_norm = (false, false)))
+                dwsep_conv_norm((3, 3), inc, outc; pad = 1, norm_layer = identity))
         push!(layers, BatchNorm(outc))
     end
     layers = start_with_relu ? layers : layers[2:end]
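The `use_norm = (false, false)` flag is gone; skipping normalisation is now expressed by passing `norm_layer = identity` and appending the desired normalisation afterwards, as `xception_block` does with `BatchNorm(outc)`. A minimal sketch of that pattern at the `conv_norm` level (illustrative call; the `Metalhead.Layers` import path is an assumption):

```julia
using Flux
using Metalhead.Layers: conv_norm

# With `norm_layer = identity`, conv_norm returns only the convolution,
# so the caller can place its own normalisation after it.
layers = conv_norm((3, 3), 32, 64, relu; pad = 1, norm_layer = identity)
block = Chain(layers..., BatchNorm(64))
```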

src/convnets/resnets/core.jl

Lines changed: 24 additions & 16 deletions
@@ -133,7 +133,7 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
                          norm_layer, revnorm)...,
              attn_fn(outplanes),
     ]
-    return Chain(filter(!=(identity), layers)...)
+    return Chain(filter!(!=(identity), layers)...)
 end
 
 ## Downsample layers
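The switch to `filter!` drops the placeholder `identity` entries in place instead of allocating a filtered copy. A standalone sketch of the pattern (the layers below are illustrative, not the actual `bottle2neck` contents):

```julia
using Flux

# Optional layers are collected as `identity` placeholders, then removed
# in place before the vector is splatted into a Chain.
layers = Any[Conv((3, 3), 16 => 16; pad = 1), identity, BatchNorm(16, relu)]
block = Chain(filter!(!=(identity), layers)...)
```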
@@ -345,34 +345,42 @@ Wide ResNet, ResNeXt and Res2Net. For an _even_ more generic model API, see [`Me
 
 # Arguments
 
-  - `block_type`: The type of block to be used in the model. This can be one of [`Metalhead.basicblock`](@ref),
-    [`Metalhead.bottleneck`](@ref) and [`Metalhead.bottle2neck`](@ref). `basicblock` is used in the
+  - `block_type`: The type of block to be used in the model. This can be one of [`Metalhead.basicblock`](@ref),
+    [`Metalhead.bottleneck`](@ref) and [`Metalhead.bottle2neck`](@ref). `basicblock` is used in the
     original ResNet paper for ResNet-18 and ResNet-34, and `bottleneck` is used in the original ResNet-50
     and ResNet-101 models, as well as for the Wide ResNet and ResNeXt models. `bottle2neck` is introduced in
     the `Res2Net` paper.
-  - `block_repeats`: A `Vector` of integers specifying the number of times each block is repeated
+  - `block_repeats`: A `Vector` of integers specifying the number of times each block is repeated
    in each stage of the ResNet model. For example, `[3, 4, 6, 3]` is the configuration used in
    ResNet-50, which has 3 blocks in the first stage, 4 blocks in the second stage, 6 blocks in the
    third stage and 3 blocks in the fourth stage.
-  - `downsample_opt`: A `NTuple` of two callbacks that are used to determine the downsampling
+  - `downsample_opt`: A `NTuple` of two callbacks that are used to determine the downsampling
    operation to be used in the model. The first callback is used to determine the convolutional
    operation to be used in the downsampling operation and the second callback is used to determine
    the identity operation to be used in the downsampling operation.
-  - `cardinality`: The number of groups to be used in the 3x3 convolutional layer in the bottleneck
+  - `cardinality`: The number of groups to be used in the 3x3 convolutional layer in the bottleneck
    block. This is usually modified from the default value of `1` in the ResNet models to `32` or `64`
    in the `ResNeXt` models.
-  - `base_width`: The base width of the convolutional layer in the blocks of the model.
-  - `inplanes`: The number of input channels in the first convolutional layer.
-  - `reduction_factor`: The reduction factor used in the model.
-  - `connection`: This is a function that determines the residual connection in the model. For
+  - `base_width`: The base width of the convolutional layer in the blocks of the model.
+  - `inplanes`: The number of input channels in the first convolutional layer.
+  - `reduction_factor`: The reduction factor used in the model.
+  - `connection`: This is a function that determines the residual connection in the model. For
    `resnets`, either of [`Metalhead.addact`](@ref) or [`Metalhead.actadd`](@ref) is recommended.
-  - `norm_layer`: The normalisation layer to be used in the model.
-  - `revnorm`: set to `true` to place the normalisation layers before the convolutions
-  - `attn_fn`: A callback that is used to determine the attention function to be used in the model.
+  - `norm_layer`: The normalisation layer to be used in the model.
+  - `revnorm`: set to `true` to place the normalisation layers before the convolutions
+  - `attn_fn`: A callback that is used to determine the attention function to be used in the model.
    See [`Metalhead.Layers.squeeze_excite`](@ref) for an example.
-  - `pool_layer`: A fully-insta
-  - `use_conv`: Set to true to use convolutions instead of identity operations in the model.
-  - `dropblock_prob`: The probability of using DropBlock in the model.
+  - `pool_layer`: A fully-instantiated pooling layer passed in to be used by the classifier head.
+    For example, `AdaptiveMeanPool((1, 1))` is used in the ResNet family by default, but something
+    like `MeanPool((3, 3))` should also work provided the dimensions after applying the pooling
+    layer are compatible with the rest of the classifier head.
+  - `use_conv`: Set to true to use convolutions instead of identity operations in the model.
+  - `dropblock_prob`: `DropBlock` probability to be used in the model. Set to `nothing` to disable
+    DropBlock. See [`Metalhead.DropBlock`](@ref) for more details.
+  - `stochastic_depth_prob`: `StochasticDepth` probability to be used in the model. Set to `nothing` to disable
+    StochasticDepth. See [`Metalhead.StochasticDepth`](@ref) for more details.
+  - `dropout_prob`: `Dropout` probability to be used in the classifier head. Set to `nothing` to
+    disable Dropout.
 """
 function resnet(block_type, block_repeats::AbstractVector{<:Integer},
                 downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
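For orientation, a hedged sketch of how the documented keyword arguments combine when calling the internal `Metalhead.resnet` builder; the configuration below is illustrative (roughly a ResNeXt-50-style setup) and is not part of this commit:

```julia
using Metalhead

# Illustrative only: bottleneck blocks repeated [3, 4, 6, 3] times with grouped
# 3x3 convolutions, controlled by `cardinality` and `base_width` as documented above.
backbone = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                            cardinality = 32, base_width = 4,
                            dropout_prob = nothing, stochastic_depth_prob = nothing)
```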

src/layers/conv.jl

Lines changed: 16 additions & 21 deletions
@@ -1,9 +1,8 @@
 """
     conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
               activation = relu; norm_layer = BatchNorm, revnorm::Bool = false,
-              eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true,
-              stride::Integer = 1, pad::Integer = 0, dilation::Integer = 1,
-              groups::Integer = 1, [bias, weight, init])
+              preact::Bool = false, stride::Integer = 1, pad::Integer = 0,
+              dilation::Integer = 1, groups::Integer = 1, [bias, weight, init])
 
 Create a convolution + normalisation layer pair with activation.
 
@@ -14,33 +13,27 @@ Create a convolution + normalisation layer pair with activation.
   - `outplanes`: number of output feature maps
   - `activation`: the activation function for the final layer
   - `norm_layer`: the normalisation layer used. Note that using `identity` as the normalisation
-    layer will result in no normalisation being applied i.e. this will be the same as
-    setting `use_norm = false`.
+    layer will result in no normalisation being applied. (This is only compatible with `preact`
+    and `revnorm` both set to `false`.)
   - `revnorm`: set to `true` to place the normalisation layer before the convolution
   - `preact`: set to `true` to place the activation function before the normalisation layer
     (only compatible with `revnorm = false`)
-  - `use_norm`: set to `false` to disable normalisation
-    (only compatible with `revnorm = false` and `preact = false`)
+  - `bias`: bias for the convolution kernel. This is set to `false` by default if
+    `norm_layer` is not `identity` and `true` otherwise.
   - `stride`: stride of the convolution kernel
   - `pad`: padding of the convolution kernel
   - `dilation`: dilation of the convolution kernel
   - `groups`: groups for the convolution kernel
-  - `bias`: bias for the convolution kernel. This is set to `false` by default if
-    `use_norm = true`.
   - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](@ref))
 """
 function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
                    activation = relu; norm_layer = BatchNorm, revnorm::Bool = false,
-                   eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true,
-                   bias = !use_norm, kwargs...)
-    # no normalization layer (including case where normalization layer is identity)
-    use_norm = use_norm && norm_layer !== identity
-    if !use_norm
+                   preact::Bool = false, bias = !(norm_layer !== identity), kwargs...)
+    # no normalization layer
+    if !(norm_layer !== identity)
         if preact || revnorm
-            throw(ArgumentError("`preact` only supported with `use_norm = true`. Check if
-                `use_norm = false` is intended. Note that it is also possible to trigger this
-                error if you set `norm_layer` to `identity` since that returns the same
-                behaviour as `use_norm`."))
+            throw(ArgumentError("`preact` only supported with `norm_layer !== identity`.
+                Check if a non-`identity` norm layer is intended."))
         else
             # early return if no norm layer is required
             return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)]
@@ -64,7 +57,7 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
     end
     # layers
     layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...),
-              norm_layer(normplanes, activations.norm; ϵ = eps)]
+              norm_layer(normplanes, activations.norm)]
     return revnorm ? reverse(layers) : layers
 end
 
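A hedged usage sketch of the reworked `conv_norm` keyword surface (values are illustrative; `conv_norm` is an internal helper, so the import path is an assumption):

```julia
using Flux
using Metalhead.Layers: conv_norm

# Default: a bias-free Conv followed by BatchNorm + relu, returned as a vector
# of layers so the caller can splat it into a Chain.
stem = conv_norm((3, 3), 3, 64, relu; stride = 2, pad = 1)

model = Chain(stem..., AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(64 => 10))
```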
@@ -86,6 +79,8 @@ TensorFlow implementation.
 """
 function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu;
                        kwargs...)
-    return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm,
-                     eps = 1.0f-3, kwargs...)
+    # TensorFlow uses a default epsilon of 1e-3 for BatchNorm
+    norm_layer = (args...; kwargs...) -> BatchNorm(args...; ϵ = 1.0f-3, kwargs...)
+    return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = norm_layer,
+                     kwargs...)
 end
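With the explicit `eps` keyword gone from `conv_norm`, a non-default BatchNorm epsilon is set the same way `basic_conv_bn` now does it: wrap `BatchNorm` in a closure and pass that as `norm_layer`. A sketch, assuming the internal helper is accessible:

```julia
using Flux
using Metalhead.Layers: conv_norm

# Pin a custom epsilon by closing over it, then hand the closure to `norm_layer`.
bn_with_eps(ϵ) = (planes, activation = identity) -> BatchNorm(planes, activation; ϵ)

layers = conv_norm((3, 3), 3, 32, relu; pad = 1, norm_layer = bn_with_eps(1.0f-3))
block = Chain(layers...)
```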

src/layers/mbconv.jl

Lines changed: 11 additions & 18 deletions
@@ -1,17 +1,16 @@
 """
     dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
-                    activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false,
-                    stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true),
-                    pad::Integer = 0, [bias, weight, init])
+                    activation = relu; norm_layer = BatchNorm, stride::Integer = 1,
+                    bias::Bool = !(norm_layer !== identity), pad::Integer = 0, [bias, weight, init])
 
 Create a depthwise separable convolution chain as used in MobileNetv1.
 This is sequence of layers:
 
   - a `kernel_size` depthwise convolution from `inplanes => inplanes`
-  - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise
+  - a (batch) normalisation layer + `activation` (if `norm_layer !== identity`; otherwise
    `activation` is applied to the convolution output)
   - a `kernel_size` convolution from `inplanes => outplanes`
-  - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise
+  - a (batch) normalisation layer + `activation` (if `norm_layer !== identity`; otherwise
    `activation` is applied to the convolution output)
 
 See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
@@ -22,25 +21,19 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
   - `inplanes`: number of input feature maps
   - `outplanes`: number of output feature maps
   - `activation`: the activation function for the final layer
-  - `revnorm`: set to `true` to place the batch norm before the convolution
-  - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and
-    second convolution
-  - `bias`: a tuple of two booleans to specify whether to use bias for the first and second
-    convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and
-    `use_norm[1] == true`.
+  - `norm_layer`: the normalisation layer used. Note that using `identity` as the normalisation
+    layer will result in no normalisation being applied.
+  - `bias`: whether to use bias in the convolution layers.
   - `stride`: stride of the first convolution kernel
   - `pad`: padding of the first convolution kernel
   - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](@ref))
 """
 function dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
-                         activation = relu; norm_layer = BatchNorm, eps::Float32 = 1.0f-5,
-                         use_norm::NTuple{2, Bool} = (true, true), stride::Integer = 1,
-                         bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...)
+                         activation = relu; norm_layer = BatchNorm, stride::Integer = 1,
+                         bias::Bool = !(norm_layer !== identity), kwargs...)
     return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, norm_layer,
-                          use_norm = use_norm[1], stride, bias = bias[1],
-                          groups = inplanes, kwargs...), # depthwise convolution
-                conv_norm((1, 1), inplanes, outplanes, activation; eps, norm_layer,
-                          use_norm = use_norm[2], bias = bias[2])) # pointwise convolution
+                          stride, bias, groups = inplanes, kwargs...), # depthwise convolution
+                conv_norm((1, 1), inplanes, outplanes, activation; eps, norm_layer, bias)) # pointwise convolution
 end
 
 """

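The depthwise separable block that `dwsep_conv_norm` documents is just two `conv_norm` calls: a grouped depthwise convolution followed by a 1x1 pointwise convolution. A standalone sketch of that structure built directly with the reworked `conv_norm` (illustrative, assuming the internal helper is in scope):

```julia
using Flux
using Metalhead.Layers: conv_norm

inplanes, outplanes = 32, 64

# Depthwise 3x3: `groups = inplanes`, so each input channel gets its own filter.
depthwise = conv_norm((3, 3), inplanes, inplanes, relu; pad = 1, groups = inplanes)
# Pointwise 1x1: mixes channels and changes the width to `outplanes`.
pointwise = conv_norm((1, 1), inplanes, outplanes, relu)

block = Chain(vcat(depthwise, pointwise)...)
```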
src/layers/selayers.jl

Lines changed: 9 additions & 17 deletions
@@ -1,38 +1,30 @@
 """
-    squeeze_excite(inplanes::Integer, squeeze_planes::Integer;
-                   norm_layer = planes -> identity, activation = relu,
-                   gate_activation = sigmoid)
-
-    squeeze_excite(inplanes::Integer; reduction::Real = 16,
-                   norm_layer = planes -> identity, activation = relu,
-                   gate_activation = sigmoid)
+    squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels,
+                   norm_layer = identity, activation = relu, gate_activation = sigmoid)
 
 Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets.
 
 # Arguments
 
   - `inplanes`: The number of input feature maps
-  - `squeeze_planes`: The number of feature maps in the intermediate layers. Alternatively,
-    specify the keyword arguments `reduction` and `rd_divisior`, which determine the number
-    of feature maps in the intermediate layers from the number of input feature maps as:
-    `squeeze_planes = _round_channels(inplanes ÷ reduction)`. (See [`_round_channels`](@ref) for details.)
+  - `reduction`: The reduction factor for the number of hidden feature maps in the
+    squeeze and excite layer. The number of hidden feature maps is calculated as
+    `round_fn(inplanes / reduction)`.
+  - `round_fn`: The function to round the number of reduced feature maps.
   - `activation`: The activation function for the first convolution layer
   - `gate_activation`: The activation function for the gate layer
   - `norm_layer`: The normalization layer to be used after the convolution layers
   - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer
 """
-function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; norm_layer = identity,
-                        activation = relu, gate_activation = sigmoid)
+function squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels,
+                        norm_layer = identity, activation = relu, gate_activation = sigmoid)
+    squeeze_planes = round_fn(inplanes ÷ reduction)
     return SkipConnection(Chain(AdaptiveMeanPool((1, 1)),
                                 conv_norm((1, 1), inplanes, squeeze_planes, activation;
                                           norm_layer)...,
                                 conv_norm((1, 1), squeeze_planes, inplanes,
                                           gate_activation; norm_layer)...), .*)
 end
-function squeeze_excite(inplanes::Integer; reduction::Real = 16,
-                        round_fn = _round_channels, kwargs...)
-    return squeeze_excite(inplanes, round_fn(inplanes / reduction); kwargs...)
-end
 
 """
     effective_squeeze_excite(inplanes, gate_activation = sigmoid)
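A hedged usage sketch of the consolidated single-method `squeeze_excite` (internal layer; the values and import path are illustrative):

```julia
using Flux
using Metalhead.Layers: squeeze_excite

# Squeeze-and-excitation over 64 feature maps; the hidden width is
# round_fn(inplanes ÷ reduction) with the defaults documented above.
se = squeeze_excite(64; reduction = 16)

x = rand(Float32, 32, 32, 64, 1)
y = se(x)           # same shape as x, channels rescaled by the learned gate
size(y) == size(x)  # true
```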
0 commit comments