
Commit a4d3f12

Initial commit for new ResNet API
1 parent 4565b2d commit a4d3f12

8 files changed, +202 -275 lines changed

docs/make.jl

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 using Pkg

-Pkg.develop(path = "..")
+Pkg.develop(; path = "..")

 using Publish
 using Artifacts, LazyArtifacts
@@ -13,5 +13,5 @@ p = Publish.Project(Metalhead)

 function build_and_deploy(label)
     rm(label; recursive = true, force = true)
-    deploy(Metalhead; root = "/Metalhead.jl", label = label)
+    return deploy(Metalhead; root = "/Metalhead.jl", label = label)
 end

docs/serve.jl

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 using Pkg

-Pkg.develop(path = "..")
+Pkg.develop(; path = "..")

 using Revise
 using Publish
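Both documentation scripts switch to the explicit-keyword call form of `Pkg.develop`. In Julia the semicolon only marks where keyword arguments begin, so the two spellings are equivalent; a minimal illustration:

    Pkg.develop(path = "..")    # keyword inferred from the name = value syntax
    Pkg.develop(; path = "..")  # keyword made explicit after the semicolon (style used in this commit)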

src/Metalhead.jl

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ include("vit-based/vit.jl")
 include("pretrain.jl")

 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
-       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
+       # ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3,
@@ -47,7 +47,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet,
+for T in (:AlexNet, :VGG, :ResNeXt, :DenseNet, # :ResNet,
           :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
           :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
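The commented-out entries temporarily pull the ResNet family out of the public exports and out of the pretty-printing loop while the new ResNet API lands. The loop's comment says it uses `Flux._big_show` for large models; a hedged sketch of what one iteration plausibly generates (the actual method body lives outside this diff and may differ):

    for T in (:AlexNet, :VGG, :ResNeXt)   # abbreviated list, illustration only
        @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = Flux._big_show(io, model)
    end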

src/convnets/densenet.jl

Lines changed: 2 additions & 1 deletion
@@ -100,7 +100,8 @@ Create a DenseNet model
   - `reduction`: the factor by which the number of feature maps is scaled across each transition
   - `nclasses`: the number of output classes
 """
-function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5, nclasses = 1000)
+function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5,
+                  nclasses = 1000) where {N}
     return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
                     reduction = reduction, nclasses = nclasses)
 end
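The reworked signature wraps the long line and adds the `where {N}` clause so the tuple length from the `NTuple{N, <:Integer}` annotation is bound. A minimal sketch of a call that would hit this method, using the standard DenseNet-121 block counts purely as an illustration (the lowercase `densenet` builder is internal and not exported):

    using Metalhead

    # (6, 12, 24, 16) dense blocks, growth rate 32, 1000 output classes
    model = Metalhead.densenet((6, 12, 24, 16); growth_rate = 32, nclasses = 1000)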

src/convnets/resnet.jl

Lines changed: 167 additions & 245 deletions
Large diffs are not rendered by default.

src/layers/Layers.jl

Lines changed: 2 additions & 1 deletion
@@ -24,5 +24,6 @@ export MHAttention,
        ChannelLayerNorm, prenorm,
        skip_identity, skip_projection,
        conv_bn, depthwise_sep_conv_bn,
-       invertedresidual, squeeze_excite
+       invertedresidual, squeeze_excite,
+       DropBlock
 end
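With `DropBlock` now in the export list of the internal `Layers` module, it can be brought in by name; a minimal sketch, assuming the submodule path `Metalhead.Layers` as laid out in this tree:

    using Metalhead.Layers: DropBlock

    db = DropBlock(0.1, 7, 1.0)   # drop_prob, block_size, gamma_scale (new field in this commit)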

src/layers/drop.jl

Lines changed: 24 additions & 21 deletions
@@ -7,45 +7,48 @@ Implements DropBlock, a regularization method for convolutional networks.
 struct DropBlock{F}
     drop_prob::F
     block_size::Integer
+    gamma_scale::F
 end
 @functor DropBlock

-(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size)
+(m::DropBlock)(x) = dropblock(x, m.drop_prob, m.block_size, m.gamma_scale)

-DropBlock(drop_prob = 0.1, block_size = 7) = DropBlock(drop_prob, block_size)
+function DropBlock(drop_prob = 0.1, block_size = 7, gamma_scale = 1.0)
+    return DropBlock(drop_prob, block_size, gamma_scale)
+end

-function _dropblock_checks(x, drop_prob, T)
+function _dropblock_checks(x, drop_prob, gamma_scale, T)
     if !(T <: AbstractArray)
         throw(ArgumentError("x must be an `AbstractArray`"))
     end
     if ndims(x) != 4
         throw(ArgumentError("x must have 4 dimensions (H, W, C, N) for `DropBlock`"))
     end
-    @assert drop_prob < 0 || drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob"
+    @assert drop_prob < 0||drop_prob > 1 "drop_prob must be between 0 and 1, got $drop_prob"
+    @assert gamma_scale < 0||gamma_scale > 1 "gamma_scale must be between 0 and 1, got $gamma_scale"
 end
-ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, T)
+ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_prob, gamma_scale, T)

-function dropblock(x::T, drop_prob, block_size::Integer) where {T}
-    _dropblock_checks(x, drop_prob, T)
+function dropblock(x::T, drop_prob, block_size::Integer, gamma_scale) where {T}
+    _dropblock_checks(x, drop_prob, gamma_scale, T)
     if drop_prob == 0
         return x
     end
-    return _dropblock(x, drop_prob, block_size)
+    return _dropblock(x, drop_prob, block_size, gamma_scale)
 end

-function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size) where {T}
-    gamma = drop_prob / (block_size ^ 2)
-    mask = rand_like(x, Float32, (size(x, 1), size(x, 2), size(x, 3)))
-    mask .<= gamma
-    block_mask = maxpool(reshape(mask, (size(mask)[1:3]..., 1)), (block_size, block_size);
-                         pad = block_size ÷ 2, stride = (1, 1))
-    if block_size % 2 == 0
-        block_mask = block_mask[1:(end - 1), 1:(end - 1), :, :]
-    end
-    block_mask = 1 .- dropdims(block_mask; dims = 4)
-    out = (x .* reshape(block_mask, (size(block_mask)[1:3]..., 1))) * length(block_mask) /
-          sum(block_mask)
-    return out
+function _dropblock(x::AbstractArray{T, 4}, drop_prob, block_size, gamma_scale) where {T}
+    H, W, _, _ = size(x)
+    total_size = H * W
+    clipped_block_size = min(block_size, min(H, W))
+    gamma = gamma_scale * drop_prob * total_size / clipped_block_size^2 /
+            ((W - block_size + 1) * (H - block_size + 1))
+    block_mask = rand_like(x) .< gamma
+    block_mask = maxpool(convert(T, block_mask), (clipped_block_size, clipped_block_size);
+                         stride = 1, padding = clipped_block_size ÷ 2)
+    block_mask = 1 .- block_mask
+    normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6))
+    return x * block_mask * normalize_scale
 end

 """

src/layers/normalise.jl

Lines changed: 2 additions & 2 deletions
@@ -19,9 +19,9 @@ end

 @functor ChannelLayerNorm

-(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))
-
 function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5)
     diag = Flux.Scale(1, 1, sz, λ)
     return ChannelLayerNorm(diag, ϵ)
 end
+
+(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))
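This hunk only moves the call overload below the constructor; the behaviour is unchanged: `ChannelLayerNorm` normalises over the channel axis (`ndims(x) - 1`, the third axis of a WHCN array) and then applies the learned `Flux.Scale`. A minimal usage sketch, assuming `ChannelLayerNorm` is pulled in from the internal `Layers` module:

    using Metalhead.Layers: ChannelLayerNorm

    ln = ChannelLayerNorm(64)           # 64 channels, identity activation
    x  = rand(Float32, 28, 28, 64, 8)   # W × H × C × N
    y  = ln(x)                          # same shape, normalised over the channel axis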
