Skip to content

Commit a1d5ddc

Browse files
committed
Make DropBlock really work
1 parent 7846f8b commit a1d5ddc

File tree

5 files changed

+57
-43
lines changed

5 files changed

+57
-43
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ version = "0.7.3"
55
[deps]
66
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
8+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1011
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1112
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
1213
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1314
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
14-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
16+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718

1819
[compat]

src/Metalhead.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ include("vit-based/vit.jl")
3939
include("pretrain.jl")
4040

4141
export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
42-
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
42+
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt,
4343
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
4444
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
4545
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
@@ -48,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
4848
ConvMixer, ConvNeXt
4949

5050
# use Flux._big_show to pretty print large models
51-
for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, :ResNet,
51+
for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt,
5252
:GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
5353
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
5454
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)

src/convnets/resne(x)t.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool =
107107
else
108108
stempool = MaxPool((3, 3); stride = 2, pad = 1)
109109
end
110-
return Chain(conv1, bn1, stempool)
110+
return inplanes, Chain(conv1, bn1, stempool)
111111
end
112112

113113
function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1),
@@ -150,7 +150,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
150150
end
151151
# Downsample block; either a (default) convolution-based block or a pooling-based block.
152152
downsample = downsample_block(downsample_fn, inplanes, planes, expansion;
153-
downsample_args...)
153+
stride, dilation, first_dilation = dilation, downsample_args...)
154154
# Construct the blocks for each stage
155155
blocks = []
156156
for block_idx in 1:num_blocks
@@ -172,16 +172,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
172172
end
173173

174174
function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
175-
stem_fn = resnet_stem, stem_args::NamedTuple = (),
176-
downsample_fn = downsample_conv, downsample_args::NamedTuple = (),
175+
stem_fn = resnet_stem, stem_args::NamedTuple = NamedTuple(),
176+
downsample_fn = downsample_conv, downsample_args::NamedTuple = NamedTuple(),
177177
drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0,
178-
drop_block_rate = 0.0),
179-
block_args::NamedTuple = ())
178+
drop_block_rate = 0.5),
179+
block_args::NamedTuple = NamedTuple())
180180
# Stem
181-
stem = stem_fn(; inchannels, stem_args...)
181+
inplanes, stem = stem_fn(; inchannels, stem_args...)
182182
# Feature Blocks
183183
channels = [64, 128, 256, 512]
184-
stage_blocks = _make_blocks(block, channels, layers, inchannels;
184+
stage_blocks = _make_blocks(block, channels, layers, inplanes;
185185
output_stride, downsample_fn, downsample_args,
186186
drop_block_rate = drop_rates.drop_block_rate,
187187
drop_path_rate = drop_rates.drop_path_rate,

src/layers/Layers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
module Layers
22

33
using Flux
4+
using CUDA
45
using NNlib
56
using NNlibCUDA
67
using Functors
78
using ChainRulesCore
89
using Statistics
910
using MLUtils
11+
using Random
1012

1113
include("../utilities.jl")
1214

src/layers/drop.jl

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
1-
"""
2-
DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0)
1+
function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size,
2+
gamma_scale, active::Bool = true) where {T}
3+
active || return x
4+
H, W, _, _ = size(x)
5+
total_size = H * W
6+
clipped_block_size = min(block_size, min(H, W))
7+
gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 /
8+
((W - block_size + 1) * (H - block_size + 1))
9+
block_mask = rand_like(rng, x) .< gamma
10+
block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size);
11+
stride = 1, pad = clipped_block_size ÷ 2)
12+
block_mask = 1 .- block_mask
13+
normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6))
14+
return x .* block_mask .* normalize_scale
15+
end
16+
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
17+
function dropblock(rng, x::CuArray, p; kwargs...)
18+
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays."))
19+
end
320

4-
Implements DropBlock, a regularization method for convolutional networks.
5-
([reference](https://arxiv.org/pdf/1810.12890.pdf))
6-
"""
7-
struct DropBlock{F}
21+
struct DropBlock{F, R <: AbstractRNG}
822
drop_block_prob::F
923
block_size::Integer
1024
gamma_scale::F
25+
active::Union{Bool, Nothing}
26+
rng::R
1127
end
12-
@functor DropBlock
13-
14-
(m::DropBlock)(x) = dropblock(x, m.drop_block_prob, m.block_size, m.gamma_scale)
1528

16-
function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0)
17-
if drop_block_prob == 0.0
18-
return identity
19-
end
20-
@assert drop_block_prob < 0 || drop_block_prob > 1
21-
"drop_block_prob must be between 0 and 1, got $drop_block_prob"
22-
@assert gamma_scale < 0 || gamma_scale > 1
23-
"gamma_scale must be between 0 and 1, got $gamma_scale"
24-
return DropBlock(drop_block_prob, block_size, gamma_scale)
25-
end
29+
@functor DropBlock
30+
trainable(a::DropBlock) = (;)
2631

2732
function _dropblock_checks(x::T) where {T}
2833
if !(T <: AbstractArray)
@@ -34,20 +39,26 @@ function _dropblock_checks(x::T) where {T}
3439
end
3540
ChainRulesCore.@non_differentiable _dropblock_checks(x)
3641

37-
function dropblock(x::AbstractArray{T, 4}, drop_block_prob, block_size,
38-
gamma_scale) where {T}
42+
function (m::DropBlock)(x)
3943
_dropblock_checks(x)
40-
H, W, _, _ = size(x)
41-
total_size = H * W
42-
clipped_block_size = min(block_size, min(H, W))
43-
gamma = gamma_scale * drop_block_prob * total_size / clipped_block_size^2 /
44-
((W - block_size + 1) * (H - block_size + 1))
45-
block_mask = rand_like(x) .< gamma
46-
block_mask = maxpool(block_mask, (clipped_block_size, clipped_block_size);
47-
stride = 1, pad = clipped_block_size ÷ 2)
48-
block_mask = 1 .- block_mask
49-
normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6))
50-
return x .* block_mask .* normalize_scale
44+
Flux._isactive(m) || return x
45+
return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale)
46+
end
47+
48+
function Flux.testmode!(m::DropBlock, mode = true)
49+
return (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
50+
end
51+
52+
function DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0,
53+
rng = Flux.rng_from_array())
54+
if drop_block_prob == 0.0
55+
return identity
56+
end
57+
@assert 0 ≤ drop_block_prob ≤ 1
58+
"drop_block_prob must be between 0 and 1, got $drop_block_prob"
59+
@assert 0 ≤ gamma_scale ≤ 1
60+
"gamma_scale must be between 0 and 1, got $gamma_scale"
61+
return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng)
5162
end
5263

5364
"""

0 commit comments

Comments
 (0)