Skip to content

Commit cd486df

Browse files
committed
Refine invertedresidual
1 parent 73df024 commit cd486df

File tree

4 files changed

+17
-11
lines changed

4 files changed

+17
-11
lines changed

src/convnets/efficientnet.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ function efficientnet(scalings, block_configs;
3636
out_channels = _round_channels(scalew(o), 8)
3737
repeats = scaled(n)
3838
push!(blocks,
39-
invertedresidual(k, in_channels, in_channels * e, out_channels, swish;
39+
invertedresidual((k, k), in_channels, in_channels * e, out_channels, swish;
4040
stride = s, reduction = 4))
4141
for _ in 1:(repeats - 1)
4242
push!(blocks,
43-
invertedresidual(k, out_channels, out_channels * e, out_channels, swish;
43+
invertedresidual((k, k), out_channels, out_channels * e, out_channels,
44+
swish;
4445
stride = 1, reduction = 4))
4546
end
4647
end

src/convnets/mobilenet/mobilenetv2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla
3030
outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8)
3131
for i in 1:n
3232
push!(layers,
33-
invertedresidual(3, inplanes, inplanes * t, outplanes, a;
33+
invertedresidual((3, 3), inplanes, inplanes * t, outplanes, a;
3434
stride = i == 1 ? s : 1))
3535
inplanes = outplanes
3636
end

src/convnets/mobilenet/mobilenetv3.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla
3636
outplanes = _round_channels(c * width_mult, 8)
3737
explanes = _round_channels(inplanes * t, 8)
3838
push!(layers,
39-
invertedresidual(k, inplanes, explanes, outplanes, a;
39+
invertedresidual((k, k), inplanes, explanes, outplanes, a;
4040
stride = s, reduction = r))
4141
inplanes = outplanes
4242
end

src/layers/conv.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,23 +114,28 @@ Create a basic inverted residual block for MobileNet variants
114114
- `reduction`: The reduction factor for the number of hidden feature maps
115115
in a squeeze and excite layer (see [`squeeze_excite`](#)).
116116
"""
117-
function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes,
118-
activation = relu; stride, reduction = nothing)
117+
function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer,
118+
outplanes::Integer, activation = relu; stride::Integer,
119+
reduction::Union{Nothing, Integer} = nothing)
119120
@assert stride in [1, 2] "`stride` has to be 1 or 2"
120121
pad = @. (kernel_size - 1) ÷ 2
121-
conv1 = (inplanes == hidden_planes) ? identity :
122-
Chain(conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false))
122+
conv1 = (inplanes == hidden_planes) ? (identity,) :
123+
conv_norm((1, 1), inplanes, hidden_planes, activation; bias = false)
123124
selayer = isnothing(reduction) ? identity :
124125
squeeze_excite(hidden_planes; reduction, activation, gate_activation = hardσ,
125126
norm_layer = BatchNorm)
126-
invres = Chain(conv1,
127+
invres = Chain(conv1...,
127128
conv_norm(kernel_size, hidden_planes, hidden_planes, activation;
128129
bias = false, stride, pad = pad, groups = hidden_planes)...,
129130
selayer,
130131
conv_norm((1, 1), hidden_planes, outplanes, identity; bias = false)...)
131132
return (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres
132133
end
133134

134-
function invertedresidual(kernel_size::Integer, args...; kwargs...)
135-
return invertedresidual((kernel_size, kernel_size), args...; kwargs...)
135+
function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer,
136+
activation = relu; stride::Integer, expansion,
137+
reduction::Union{Nothing, Integer} = nothing)
138+
hidden_planes = Int(inplanes * expansion)
139+
return invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation;
140+
stride, reduction)
136141
end

0 commit comments

Comments
 (0)