
Commit 97c3911 (1 parent: 4f487a0)

Add StochasticDepth to EfficientNets

More cleanup

11 files changed: +60 -68 lines changed

src/convnets/builders/irmodel.jl

Lines changed: 5 additions & 3 deletions
@@ -2,8 +2,9 @@ function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector
                         inplanes::Integer = 32, connection = +, activation = relu,
                         norm_layer = BatchNorm, divisor::Integer = 8,
                         tail_conv::Bool = true, expanded_classifier::Bool = false,
-                        headplanes::Integer, dropout_prob = nothing,
-                        inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...)
+                        stochastic_depth_prob = nothing, headplanes::Integer,
+                        dropout_prob = nothing, inchannels::Integer = 3,
+                        nclasses::Integer = 1000, kwargs...)
     width_mult, _ = scalings
     # building first layer
     inplanes = _round_channels(inplanes * width_mult, divisor)
@@ -13,7 +14,8 @@ function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector
                    norm_layer))
     # building inverted residual blocks
     get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings;
-                                                     norm_layer, divisor, kwargs...)
+                                                     stochastic_depth_prob, norm_layer,
+                                                     divisor, kwargs...)
     append!(layers, cnn_stages(get_layers, block_repeats, connection))
     # building last layers
     outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
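The change here is pure plumbing: `irmodelbuilder` grows a `stochastic_depth_prob` keyword and forwards it untouched to `mbconv_stage_builder`. For readers unfamiliar with the Julia idiom, a minimal self-contained sketch of this kind of keyword forwarding (toy functions, not the code in this commit):

    stage_builder(; stochastic_depth_prob = nothing, divisor = 8) =
        (stochastic_depth_prob, divisor)
    model_builder(; stochastic_depth_prob = nothing, kwargs...) =
        stage_builder(; stochastic_depth_prob, kwargs...)  # `; x` is shorthand for `x = x`

    model_builder(stochastic_depth_prob = 0.2)  # returns (0.2, 8)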

src/convnets/builders/mbconv.jl

Lines changed: 22 additions & 17 deletions
@@ -23,10 +23,10 @@ end
 
 function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
                       inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
-                      norm_layer = BatchNorm, divisor::Integer = 8,
-                      se_from_explanes::Bool = false, kwargs...)
+                      stochastic_depth_prob = nothing, norm_layer = BatchNorm,
+                      divisor::Integer = 8, se_from_explanes::Bool = false, kwargs...)
     width_mult, depth_mult = scalings
-    block_repeats = [ceil(Int, block_configs[idx][end - 3] * depth_mult)
+    block_repeats = [ceil(Int, block_configs[idx][end - 2] * depth_mult)
                      for idx in eachindex(block_configs)]
     block_fn, k, outplanes, expansion, stride, _, reduction, activation = block_configs[stage_idx]
     # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
@@ -37,48 +37,53 @@ function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
         inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
     end
     outplanes = _round_channels(outplanes * width_mult, divisor)
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
     function get_layers(block_idx::Integer)
         inplanes = block_idx == 1 ? inplanes : outplanes
         explanes = _round_channels(inplanes * expansion, divisor)
         stride = block_idx == 1 ? stride : 1
         block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer,
                          stride, reduction, kwargs...)
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
-        drop_path = StochasticDepth(pathschedule(schedule_idx))
-        return stride == 1 && inplanes == outplanes ? (drop_path, block) : (block,)
+        use_skip = stride == 1 && inplanes == outplanes
+        if use_skip
+            schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+
+            drop_path = StochasticDepth(sdschedule[schedule_idx])
+            return (drop_path, block)
+        else
+            return (block,)
+        end
     end
-    return get_layers, ceil(Int, nrepeats * depth_mult)
+    return get_layers, block_repeats[stage_idx]
 end
 
 function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
                       inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
-                      norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
-    width_mult, depth_mult = scaling
-    block_repeats = [ceil(Int, block_configs[idx][end - 2] * depth_mult)
+                      stochastic_depth_prob = nothing, norm_layer = BatchNorm,
+                      divisor::Integer = 8, kwargs...)
+    width_mult, depth_mult = scalings
+    block_repeats = [ceil(Int, block_configs[idx][end - 1] * depth_mult)
                      for idx in eachindex(block_configs)]
     block_fn, k, outplanes, expansion, stride, _, activation = block_configs[stage_idx]
     inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
     outplanes = _round_channels(outplanes * width_mult, divisor)
-    block_repeats = sum(block_configs[idx][4] for idx in 1:stage_idx)
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
     function get_layers(block_idx::Integer)
         inplanes = block_idx == 1 ? inplanes : outplanes
         explanes = _round_channels(inplanes * expansion, divisor)
         stride = block_idx == 1 ? stride : 1
         block = block_fn((k, k), inplanes, explanes, outplanes, activation;
                          norm_layer, stride, kwargs...)
         schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
-        drop_path = StochasticDepth(pathschedule(schedule_idx))
+        drop_path = StochasticDepth(sdschedule[schedule_idx])
         return stride == 1 && inplanes == outplanes ? (drop_path, block) : (block,)
     end
     return get_layers, block_repeats[stage_idx]
 end
 
 function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
                               scalings::NTuple{2, Real}; kwargs...)
-    bxs = [_get_builder(block_configs[1], block_configs, inplanes, idx, scalings;
-                        kwargs...)
-           for idx in eachindex(block_configs)]
+    bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx, scalings;
+                        kwargs...) for idx in eachindex(block_configs)]
     return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
 end
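Beyond the `sdschedule` rename, this file fixes genuine bugs visible in the deleted lines: the schedule was being called like a function (`pathschedule(schedule_idx)`) instead of indexed, the `fused_mbconv` method referenced an undefined `scaling`, the repeat count was read from the wrong tuple position, and the old return value used an undefined `nrepeats`. The new guard also avoids constructing a `StochasticDepth` layer that would immediately be discarded. A self-contained toy of the guard (illustrative only, not the code above):

    # Only blocks with an identity shortcut (stride 1 and matching channel counts)
    # can be dropped, so only those get a stochastic depth probability.
    function layers_for_block(stride, inplanes, outplanes, p)
        use_skip = stride == 1 && inplanes == outplanes
        return use_skip ? ((:drop_path, p), :block) : (:block,)
    end

    layers_for_block(1, 64, 64, 0.1)   # ((:drop_path, 0.1), :block)
    layers_for_block(2, 64, 128, 0.1)  # (:block,)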

src/convnets/builders/resblocks.jl

Lines changed: 10 additions & 10 deletions
@@ -6,14 +6,14 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
                             dropblock_prob = nothing, stochastic_depth_prob = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
-    # DropBlock, StochasticDepth both take in rates based on a linear scaling schedule
+    # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule
     # Also get `planes_vec` needed for block `inplanes` and `planes` calculations
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
     planes_vec = collect(planes_fn(block_repeats))
     # closure over `idxs`
     function get_layers(stage_idx::Integer, block_idx::Integer)
-        # DropBlock, StochasticDepth both take in rates based on a linear scaling schedule
+        # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule
         # This is also needed for block `inplanes` and `planes` calculations
         schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         planes = planes_vec[schedule_idx]
@@ -23,8 +23,8 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = stride != 1 || inplanes != planes * expansion ?
                         downsample_tuple[1] : downsample_tuple[2]
-        drop_path = StochasticDepth(pathschedule[schedule_idx])
-        drop_block = DropBlock(blockschedule[schedule_idx])
+        drop_path = StochasticDepth(sdschedule[schedule_idx])
+        drop_block = DropBlock(dbschedule[schedule_idx])
         block = basicblock(inplanes, planes; stride, reduction_factor, activation,
                            norm_layer, revnorm, attn_fn, drop_path, drop_block)
         downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
@@ -43,8 +43,8 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
                             dropblock_prob = nothing, stochastic_depth_prob = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
     planes_vec = collect(planes_fn(block_repeats))
     # closure over `idxs`
     function get_layers(stage_idx::Integer, block_idx::Integer)
@@ -58,8 +58,8 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = stride != 1 || inplanes != planes * expansion ?
                         downsample_tuple[1] : downsample_tuple[2]
-        drop_path = StochasticDepth(pathschedule[schedule_idx])
-        drop_block = DropBlock(blockschedule[schedule_idx])
+        drop_path = StochasticDepth(sdschedule[schedule_idx])
+        drop_block = DropBlock(dbschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
                            reduction_factor, activation, norm_layer, revnorm,
                            attn_fn, drop_path, drop_block)
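The `schedule_idx` arithmetic shared by both builders flattens `(stage_idx, block_idx)` into a single index, so one linear schedule spans the whole network rather than restarting in each stage. A quick self-contained check of that arithmetic:

    block_repeats = [2, 2, 2, 2]  # a ResNet-18-style stage layout
    schedule_idx(s, b) = sum(block_repeats[1:(s - 1)]) + b

    schedule_idx(1, 1)  # 1 (the sum over an empty slice is 0)
    schedule_idx(3, 2)  # 6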

src/convnets/convnext.jl

Lines changed: 2 additions & 2 deletions
@@ -56,11 +56,11 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
                                   norm_layer = ChannelLayerNorm, revnorm = true)...))
     end
     stages = []
-    dp_rates = linear_scheduler(stochastic_depth_prob; depth = sum(depths))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(depths))
     cur = 0
     for i in eachindex(depths)
         push!(stages,
-              [convnextblock(planes[i], dp_rates[cur + j], layerscale_init)
+              [convnextblock(planes[i], sdschedule[cur + j], layerscale_init)
               for j in 1:depths[i]])
         cur += depths[i]
     end

src/convnets/efficientnets/efficientnet.jl

Lines changed: 2 additions & 2 deletions
@@ -31,13 +31,13 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
                                          :b7 => (600, (2.0, 3.1)),
                                          :b8 => (672, (2.2, 3.6)))
 
-function efficientnet(config::Symbol; norm_layer = BatchNorm,
+function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2,
                       dropout_prob = nothing, inchannels::Integer = 3,
                       nclasses::Integer = 1000)
     _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
     scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
     return irmodelbuilder(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32,
-                          norm_layer, activation = swish,
+                          norm_layer, stochastic_depth_prob, activation = swish,
                           headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4,
                           dropout_prob, inchannels, nclasses)
 end
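A hedged usage sketch for the new keyword. `efficientnet` here is an internal, non-exported builder, and the repository is not named on this page (the file layout matches Metalhead.jl, which is assumed below); the exported model wrapper is assumed to forward these keywords:

    using Metalhead  # assumption: this commit is against Metalhead.jl
    m = Metalhead.efficientnet(:b0; stochastic_depth_prob = 0.1)  # milder drop-path
    b = Metalhead.efficientnet(:b0)  # picks up the new default of 0.2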

src/convnets/efficientnets/efficientnetv2.jl

Lines changed: 4 additions & 3 deletions
@@ -35,13 +35,14 @@ const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish),
                               (mbconv, 3, 512, 6, 2, 32, 4, swish),
                               (mbconv, 3, 768, 6, 1, 8, 4, swish)])
 
-function efficientnetv2(config::Symbol; norm_layer = BatchNorm, dropout_prob = nothing,
-                        inchannels::Integer = 3, nclasses::Integer = 1000)
+function efficientnetv2(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2,
+                        dropout_prob = nothing, inchannels::Integer = 3,
+                        nclasses::Integer = 1000)
     _checkconfig(config, keys(EFFNETV2_CONFIGS))
     block_configs = EFFNETV2_CONFIGS[config]
     return irmodelbuilder((1, 1), block_configs; activation = swish, norm_layer,
                           inplanes = block_configs[1][3], headplanes = 1280,
-                          dropout_prob, inchannels, nclasses)
+                          stochastic_depth_prob, dropout_prob, inchannels, nclasses)
 end
 
 """

src/layers/drop.jl

Lines changed: 0 additions & 18 deletions
@@ -62,15 +62,6 @@ It can be used in two ways: either with all blocks having the same survival prob
 or with a linear scaling rule across the blocks. This is performed only at training time.
 At test time, the `DropBlock` layer is equivalent to `identity`.
 
-!!! warning
-
-    In the case of the linear scaling rule, the calculations of survival probabilities for each
-    block may lead to a survival probability > 1 for a given block. This will lead to
-    `DropBlock` erroring. This usually happens with a low number of blocks and a high base
-    survival probability, so in such cases it is recommended to use a fixed base survival
-    probability across blocks. If this is not desired, then a lower base survival probability
-    is recommended.
-
 ([reference](https://arxiv.org/abs/1810.12890))
 
 # Arguments
@@ -141,15 +132,6 @@ all blocks having the same survival probability or with a linear scaling rule ac
 blocks. This is performed only at training time. At test time, the `StochasticDepth` layer is
 equivalent to `identity`.
 
-!!! warning
-
-    In the case of the linear scaling rule, the calculations of survival probabilities for each
-    block may lead to a survival probability > 1 for a given block. This will lead to
-    `StochasticDepth` erroring. This usually happens with a low number of blocks and a high base
-    survival probability, so in such cases it is recommended to use a fixed base survival
-    probability across blocks. If this is not desired, then a lower base survival probability
-    is recommended.
-
 # Arguments
 
 - `p`: probability of Stochastic Depth.
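Deleting these warnings is justified by the scheduler rework in src/utilities.jl later in this commit: the schedule now runs from `start_value` up to but excluding `drop_prob`, so a scheduled probability can never exceed the base value, let alone 1. A quick self-contained check of the new schedule body:

    probs = LinRange(0.0, 0.2, 8 + 1)[1:(end - 1)]  # linear_scheduler(0.2; depth = 8)
    all(0.0 .<= probs .< 0.2)  # true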

src/layers/mlp.jl

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# TODO @theabhirath figure out consistent behaviour for dropout rates - 0.0 vs `nothing`
+# TODO @theabhirath figure out consistent behaviour for dropout probs - 0.0 vs `nothing`
 """
     mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
               dropout_prob = 0., activation = gelu)

src/mixers/core.jl

Lines changed: 3 additions & 2 deletions
@@ -27,9 +27,10 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); norm_layer = LayerNorm,
                   depth::Integer = 12, inchannels::Integer = 3, nclasses::Integer = 1000,
                   kwargs...)
     npatches = prod(imsize .÷ patch_size)
-    dp_rates = linear_scheduler(stochastic_depth_prob; depth)
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth)
     layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes),
-                   Chain([block(embedplanes, npatches; stochastic_depth_prob = dp_rates[i],
+                   Chain([block(embedplanes, npatches;
+                                stochastic_depth_prob = sdschedule[i],
                                 kwargs...)
                           for i in 1:depth]...))
     classifier = Chain(norm_layer(embedplanes), seconddimmean, Dense(embedplanes, nclasses))

src/utilities.jl

Lines changed: 9 additions & 8 deletions
@@ -66,17 +66,18 @@ function _maybe_big_show(io, model)
 end
 
 """
-    linear_scheduler(drop_rate = 0.0; start_value = 0.0, depth)
-    linear_scheduler(drop_rate::Nothing; depth::Integer)
+    linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
+    linear_scheduler(drop_prob::Nothing; depth::Integer)
 
-Returns the dropout rates for a given depth using the linear scaling rule. If the
-`drop_rate` is `nothing`, it returns a `Vector` of length `depth` with all values
-equal to `nothing`.
+Returns the dropout probabilities for a given depth using the linear scaling rule. Note
+that this returns evenly spaced values between `start_value` and `drop_prob`, not including
+`drop_prob`. If `drop_prob` is `nothing`, it returns a `Vector` of length `depth` with all
+values equal to `nothing`.
 """
-function linear_scheduler(drop_rate = 0.0; depth::Integer, start_value = 0.0)
-    return LinRange(start_value, drop_rate, depth)
+function linear_scheduler(drop_prob = 0.0; depth::Integer, start_value = 0.0)
+    return LinRange(start_value, drop_prob, depth + 1)[1:(end - 1)]
 end
-linear_scheduler(drop_rate::Nothing; depth::Integer) = fill(drop_rate, depth)
+linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)
 
 # Utility function for depth and configuration checks in models
 function _checkconfig(config, configs)
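The docstring's new caveat (the endpoint is excluded) is easy to verify. A self-contained reimplementation matching the diff body above:

    linear_scheduler(drop_prob = 0.0; depth::Integer, start_value = 0.0) =
        LinRange(start_value, drop_prob, depth + 1)[1:(end - 1)]
    linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)

    collect(linear_scheduler(0.2; depth = 4))  # [0.0, 0.05, 0.1, 0.15]
    linear_scheduler(nothing; depth = 3)       # [nothing, nothing, nothing]

Because the endpoint is dropped, the deepest block gets probability `drop_prob * (depth - 1) / depth` rather than `drop_prob` itself.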
