Skip to content

Commit aa1d37d

Browse files
committed
Account for width scaling in input channel size
1 parent 61d886b commit aa1d37d

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/convnets/efficientnet.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,23 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
2323
function efficientnet(scalings, block_config;
2424
inchannels = 3, nclasses = 1000, max_width = 1280)
2525
wscale, dscale = scalings
26-
out_channels = _round_channels(32, 8)
26+
scalew(w) = wscale 1 ? w : ceil(Int64, wscale * w)
27+
scaled(d) = dscale 1 ? d : ceil(Int64, dscale * d)
28+
29+
out_channels = _round_channels(scalew(32), 8)
2730
stem = conv_bn((3, 3), inchannels, out_channels, swish;
2831
bias = false, stride = 2, pad = SamePad())
2932

3033
blocks = []
3134
for (n, k, s, e, i, o) in block_config
32-
in_channels = _round_channels(i, 8)
33-
out_channels = _round_channels(wscale 1 ? o : ceil(Int64, wscale * o), 8)
34-
repeat = dscale 1 ? n : ceil(Int64, dscale * n)
35+
in_channels = _round_channels(scalew(i), 8)
36+
out_channels = _round_channels(scalew(o), 8)
37+
repeats = scaled(n)
3538

3639
push!(blocks,
3740
invertedresidual(k, in_channels, in_channels * e, out_channels, swish;
3841
stride = s, reduction = 4))
39-
for _ in 1:(repeat - 1)
42+
for _ in 1:(repeats - 1)
4043
push!(blocks,
4144
invertedresidual(k, out_channels, out_channels * e, out_channels, swish;
4245
stride = 1, reduction = 4))

0 commit comments

Comments
 (0)