Commit 4cdcaef

Add StochasticDepth to EfficientNets
More cleanup
1 parent 4f487a0 commit 4cdcaef

17 files changed, +85 -94 lines changed

src/convnets/builders/irmodel.jl

Lines changed: 10 additions & 9 deletions
@@ -1,9 +1,9 @@
-function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple};
-                        inplanes::Integer = 32, connection = +, activation = relu,
-                        norm_layer = BatchNorm, divisor::Integer = 8,
-                        tail_conv::Bool = true, expanded_classifier::Bool = false,
-                        headplanes::Integer, dropout_prob = nothing,
-                        inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...)
+function build_irmodel(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple};
+                       inplanes::Integer = 32, connection = +, activation = relu,
+                       norm_layer = BatchNorm, divisor::Integer = 8, tail_conv::Bool = true,
+                       expanded_classifier::Bool = false, stochastic_depth_prob = nothing,
+                       headplanes::Integer, dropout_prob = nothing, inchannels::Integer = 3,
+                       nclasses::Integer = 1000, kwargs...)
     width_mult, _ = scalings
     # building first layer
     inplanes = _round_channels(inplanes * width_mult, divisor)
@@ -13,7 +13,8 @@ function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector
                              norm_layer))
     # building inverted residual blocks
     get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings;
-                                                     norm_layer, divisor, kwargs...)
+                                                     stochastic_depth_prob, norm_layer,
+                                                     divisor, kwargs...)
     append!(layers, cnn_stages(get_layers, block_repeats, connection))
     # building last layers
     outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
@@ -36,6 +37,6 @@ function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector
     return Chain(Chain(layers...), classifier)
 end
 
-function irmodelbuilder(width_mult::Real, block_configs::AbstractVector{<:Tuple}; kwargs...)
-    return irmodelbuilder((width_mult, 1), block_configs; kwargs...)
+function build_irmodel(width_mult::Real, block_configs::AbstractVector{<:Tuple}; kwargs...)
+    return build_irmodel((width_mult, 1), block_configs; kwargs...)
 end

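The scalar method at the end of this hunk simply forwards to the tuple method, fixing the depth multiplier at 1, and `stochastic_depth_prob` is now threaded through to the stage builder. A minimal sketch of the equivalence, assuming the internal builder is in scope; `toy_configs` and the keyword values below are placeholders, not a specific Metalhead model:

    # `toy_configs` is a hypothetical stand-in for a vector of block-config tuples.
    m1 = build_irmodel(1.0, toy_configs; headplanes = 1280, stochastic_depth_prob = 0.2)
    m2 = build_irmodel((1.0, 1), toy_configs; headplanes = 1280, stochastic_depth_prob = 0.2)
    # m1 and m2 have the same architecture: only the width is scaled, the depth is left as-is.
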
src/convnets/builders/mbconv.jl

Lines changed: 22 additions & 17 deletions
@@ -23,10 +23,10 @@ end
 
 function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
                       inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
-                      norm_layer = BatchNorm, divisor::Integer = 8,
-                      se_from_explanes::Bool = false, kwargs...)
+                      stochastic_depth_prob = nothing, norm_layer = BatchNorm,
+                      divisor::Integer = 8, se_from_explanes::Bool = false, kwargs...)
     width_mult, depth_mult = scalings
-    block_repeats = [ceil(Int, block_configs[idx][end - 3] * depth_mult)
+    block_repeats = [ceil(Int, block_configs[idx][end - 2] * depth_mult)
                      for idx in eachindex(block_configs)]
     block_fn, k, outplanes, expansion, stride, _, reduction, activation = block_configs[stage_idx]
     # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
@@ -37,48 +37,53 @@ function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
         inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
     end
     outplanes = _round_channels(outplanes * width_mult, divisor)
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
     function get_layers(block_idx::Integer)
         inplanes = block_idx == 1 ? inplanes : outplanes
         explanes = _round_channels(inplanes * expansion, divisor)
         stride = block_idx == 1 ? stride : 1
         block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer,
                          stride, reduction, kwargs...)
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
-        drop_path = StochasticDepth(pathschedule(schedule_idx))
-        return stride == 1 && inplanes == outplanes ? (drop_path, block) : (block,)
+        use_skip = stride == 1 && inplanes == outplanes
+        if use_skip
+            schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+
+            drop_path = StochasticDepth(sdschedule[schedule_idx])
+            return (drop_path, block)
+        else
+            return (block,)
+        end
     end
-    return get_layers, ceil(Int, nrepeats * depth_mult)
+    return get_layers, block_repeats[stage_idx]
 end
 
 function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
                       inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
-                      norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
-    width_mult, depth_mult = scaling
-    block_repeats = [ceil(Int, block_configs[idx][end - 2] * depth_mult)
+                      stochastic_depth_prob = nothing, norm_layer = BatchNorm,
+                      divisor::Integer = 8, kwargs...)
+    width_mult, depth_mult = scalings
+    block_repeats = [ceil(Int, block_configs[idx][end - 1] * depth_mult)
                      for idx in eachindex(block_configs)]
     block_fn, k, outplanes, expansion, stride, _, activation = block_configs[stage_idx]
     inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
     outplanes = _round_channels(outplanes * width_mult, divisor)
-    block_repeats = sum(block_configs[idx][4] for idx in 1:stage_idx)
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
     function get_layers(block_idx::Integer)
         inplanes = block_idx == 1 ? inplanes : outplanes
         explanes = _round_channels(inplanes * expansion, divisor)
         stride = block_idx == 1 ? stride : 1
         block = block_fn((k, k), inplanes, explanes, outplanes, activation;
                          norm_layer, stride, kwargs...)
         schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
-        drop_path = StochasticDepth(pathschedule(schedule_idx))
+        drop_path = StochasticDepth(sdschedule[schedule_idx])
        return stride == 1 && inplanes == outplanes ? (drop_path, block) : (block,)
     end
     return get_layers, block_repeats[stage_idx]
 end
 
 function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
                               scalings::NTuple{2, Real}; kwargs...)
-    bxs = [_get_builder(block_configs[1], block_configs, inplanes, idx, scalings;
-                        kwargs...)
-           for idx in eachindex(block_configs)]
+    bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx, scalings;
+                        kwargs...) for idx in eachindex(block_configs)]
     return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
 end

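The key behavioral change above is that a `StochasticDepth` layer is only attached when the block has an identity skip (`stride == 1 && inplanes == outplanes`), with a drop probability taken from a linear schedule over the total block count. A self-contained toy illustration of the indexing in plain Julia; `toy_linear_scheduler` and the repeat counts are illustrative stand-ins, not Metalhead's actual `linear_scheduler`:

    # Linear ramp of drop probabilities over the whole network depth.
    toy_linear_scheduler(prob::Real; depth::Integer) = collect(LinRange(0.0, prob, depth))
    toy_linear_scheduler(::Nothing; depth::Integer) = zeros(depth)   # no stochastic depth

    block_repeats = [1, 2, 2, 3, 3, 4, 1]        # per-stage repeats (EfficientNet-B0-like)
    sdschedule = toy_linear_scheduler(0.2; depth = sum(block_repeats))

    stage_idx, block_idx = 4, 2                  # second block of the fourth stage
    schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
    sdschedule[schedule_idx]                     # ≈ 0.08: deeper blocks are dropped more often
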
src/convnets/builders/resblocks.jl

Lines changed: 10 additions & 10 deletions
@@ -6,14 +6,14 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
                             dropblock_prob = nothing, stochastic_depth_prob = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
-    # DropBlock, StochasticDepth both take in rates based on a linear scaling schedule
+    # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule
     # Also get `planes_vec` needed for block `inplanes` and `planes` calculations
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
     planes_vec = collect(planes_fn(block_repeats))
     # closure over `idxs`
     function get_layers(stage_idx::Integer, block_idx::Integer)
-        # DropBlock, StochasticDepth both take in rates based on a linear scaling schedule
+        # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule
         # This is also needed for block `inplanes` and `planes` calculations
         schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         planes = planes_vec[schedule_idx]
@@ -23,8 +23,8 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = stride != 1 || inplanes != planes * expansion ?
                         downsample_tuple[1] : downsample_tuple[2]
-        drop_path = StochasticDepth(pathschedule[schedule_idx])
-        drop_block = DropBlock(blockschedule[schedule_idx])
+        drop_path = StochasticDepth(sdschedule[schedule_idx])
+        drop_block = DropBlock(dbschedule[schedule_idx])
         block = basicblock(inplanes, planes; stride, reduction_factor, activation,
                            norm_layer, revnorm, attn_fn, drop_path, drop_block)
         downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
@@ -43,8 +43,8 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
                             dropblock_prob = nothing, stochastic_depth_prob = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
-    pathschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
+    dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats))
     planes_vec = collect(planes_fn(block_repeats))
     # closure over `idxs`
     function get_layers(stage_idx::Integer, block_idx::Integer)
@@ -58,8 +58,8 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = stride != 1 || inplanes != planes * expansion ?
                         downsample_tuple[1] : downsample_tuple[2]
-        drop_path = StochasticDepth(pathschedule[schedule_idx])
-        drop_block = DropBlock(blockschedule[schedule_idx])
+        drop_path = StochasticDepth(sdschedule[schedule_idx])
+        drop_block = DropBlock(dbschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
                            reduction_factor, activation, norm_layer, revnorm,
                            attn_fn, drop_path, drop_block)

src/convnets/builders/resnet.jl

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-function resnetbuilder(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
-                       connection, classifier_fn)
+function build_resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
+                      connection, classifier_fn)
     # Build stages of the ResNet
     stage_blocks = cnn_stages(get_layers, block_repeats, connection)
     backbone = Chain(stem, stage_blocks...)

src/convnets/convnext.jl

Lines changed: 2 additions & 2 deletions
@@ -56,11 +56,11 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
                                      norm_layer = ChannelLayerNorm, revnorm = true)...))
     end
     stages = []
-    dp_rates = linear_scheduler(stochastic_depth_prob; depth = sum(depths))
+    sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(depths))
     cur = 0
     for i in eachindex(depths)
         push!(stages,
-              [convnextblock(planes[i], dp_rates[cur + j], layerscale_init)
+              [convnextblock(planes[i], sdschedule[cur + j], layerscale_init)
                for j in 1:depths[i]])
         cur += depths[i]
     end

src/convnets/efficientnets/efficientnet.jl

Lines changed: 5 additions & 5 deletions
@@ -31,15 +31,15 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
                                          :b7 => (600, (2.0, 3.1)),
                                          :b8 => (672, (2.2, 3.6)))
 
-function efficientnet(config::Symbol; norm_layer = BatchNorm,
+function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2,
                       dropout_prob = nothing, inchannels::Integer = 3,
                       nclasses::Integer = 1000)
     _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
     scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
-    return irmodelbuilder(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32,
-                          norm_layer, activation = swish,
-                          headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4,
-                          dropout_prob, inchannels, nclasses)
+    return build_irmodel(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32,
+                         norm_layer, stochastic_depth_prob, activation = swish,
+                         headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4,
+                         dropout_prob, inchannels, nclasses)
 end
 
 """

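With this change the internal `efficientnet` constructor enables stochastic depth by default (`stochastic_depth_prob = 0.2`) and forwards it to `build_irmodel`. A hedged usage sketch of the internal function shown above; the exported, user-facing API may differ, and the specific values are illustrative:

    # Default: stochastic depth with a final drop probability of 0.2.
    m_default = efficientnet(:b0)
    # Opt out by passing `nothing`, matching the builder's own default:
    m_plain = efficientnet(:b0; stochastic_depth_prob = nothing)
    # Or strengthen the regularisation for a larger variant:
    m_b4 = efficientnet(:b4; stochastic_depth_prob = 0.3)
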
src/convnets/efficientnets/efficientnetv2.jl

Lines changed: 6 additions & 5 deletions
@@ -35,13 +35,14 @@ const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish),
                                          (mbconv, 3, 512, 6, 2, 32, 4, swish),
                                          (mbconv, 3, 768, 6, 1, 8, 4, swish)])
 
-function efficientnetv2(config::Symbol; norm_layer = BatchNorm, dropout_prob = nothing,
-                        inchannels::Integer = 3, nclasses::Integer = 1000)
+function efficientnetv2(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2,
+                        dropout_prob = nothing, inchannels::Integer = 3,
+                        nclasses::Integer = 1000)
     _checkconfig(config, keys(EFFNETV2_CONFIGS))
     block_configs = EFFNETV2_CONFIGS[config]
-    return irmodelbuilder((1, 1), block_configs; activation = swish, norm_layer,
-                          inplanes = block_configs[1][3], headplanes = 1280,
-                          dropout_prob, inchannels, nclasses)
+    return build_irmodel((1, 1), block_configs; activation = swish, norm_layer,
+                         inplanes = block_configs[1][3], headplanes = 1280,
+                         stochastic_depth_prob, dropout_prob, inchannels, nclasses)
 end
 
 """

src/convnets/mobilenets/mnasnet.jl

Lines changed: 2 additions & 2 deletions
@@ -45,8 +45,8 @@ function mnasnet(config::Symbol; width_mult::Real = 1, max_width::Integer = 1280
     # momentum used for BatchNorm is as per Tensorflow implementation
     norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = 0.0003f0, kwargs...)
     inplanes, block_configs = MNASNET_CONFIGS[config]
-    return irmodelbuilder(width_mult, block_configs; inplanes, norm_layer,
-                          headplanes = max_width, dropout_prob, inchannels, nclasses)
+    return build_irmodel(width_mult, block_configs; inplanes, norm_layer,
+                         headplanes = max_width, dropout_prob, inchannels, nclasses)
 end
 
 """

src/convnets/mobilenets/mobilenetv1.jl

Lines changed: 3 additions & 3 deletions
@@ -16,9 +16,9 @@ const MOBILENETV1_CONFIGS = [
 
 function mobilenetv1(width_mult::Real = 1; inplanes::Integer = 32, dropout_prob = nothing,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
-    return irmodelbuilder(width_mult, MOBILENETV1_CONFIGS; inplanes, inchannels,
-                          activation = relu6, connection = nothing, tail_conv = false,
-                          headplanes = 1024, dropout_prob, nclasses)
+    return build_irmodel(width_mult, MOBILENETV1_CONFIGS; inplanes, inchannels,
+                         activation = relu6, connection = nothing, tail_conv = false,
+                         headplanes = 1024, dropout_prob, nclasses)
 end
 
 """

src/convnets/mobilenets/mobilenetv2.jl

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@ const MOBILENETV2_CONFIGS = [
 function mobilenetv2(width_mult::Real = 1; max_width::Integer = 1280,
                      inplanes::Integer = 32, dropout_prob = 0.2,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
-    return irmodelbuilder(width_mult, MOBILENETV2_CONFIGS; activation = relu6, inplanes,
-                          headplanes = max_width, dropout_prob, inchannels, nclasses)
+    return build_irmodel(width_mult, MOBILENETV2_CONFIGS; activation = relu6, inplanes,
+                         headplanes = max_width, dropout_prob, inchannels, nclasses)
 end
 
 """
