Skip to content

Commit 3be1d81

Browse files
committed
Construct the stem outside and pass it into resnet
`downsample_args` is actually redundant
1 parent a1d5ddc commit 3be1d81

File tree

2 files changed

+29
-33
lines changed

2 files changed

+29
-33
lines changed

src/Metalhead.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using BSON
77
using Artifacts, LazyArtifacts
88
using Statistics
99
using MLUtils
10-
using Random
1110

1211
import Functors
1312

src/convnets/resne(x)t.jl

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
# returns `DropBlock`s for each block of the ResNet
2-
function _drop_blocks(drop_block_prob = 0.0)
3-
return [
4-
identity,
5-
identity,
6-
DropBlock(drop_block_prob, 5, 0.25),
7-
DropBlock(drop_block_prob, 3, 1.00),
8-
]
9-
end
10-
111
function downsample_conv(kernel_size, inplanes, outplanes; stride = 1, dilation = 1,
122
first_dilation = nothing, norm_layer = BatchNorm)
133
kernel_size = stride == 1 && dilation == 1 ? (1, 1) : kernel_size
@@ -80,12 +70,15 @@ expansion_factor(::typeof(bottleneck)) = 4
8070

8171
function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool = false,
8272
norm_layer = BatchNorm, activation = relu)
83-
@assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]"
73+
@assert stem_type in [:default, :deep, :deep_tiered]
74+
"Stem type must be one of [:default, :deep, :deep_tiered]"
8475
# Main stem
85-
inplanes = stem_type == :deep ? stem_width * 2 : 64
86-
if stem_type == :deep
87-
stem_channels = (stem_width, stem_width)
88-
if stem_type == :deep_tiered
76+
deep_stem = stem_type == :deep || stem_type == :deep_tiered
77+
inplanes = deep_stem ? stem_width * 2 : 64
78+
if deep_stem
79+
if stem_type == :deep
80+
stem_channels = (stem_width, stem_width)
81+
elseif stem_type == :deep_tiered
8982
stem_channels = (3 * (stem_width ÷ 4), stem_width)
9083
end
9184
conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1,
@@ -107,7 +100,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool =
107100
else
108101
stempool = MaxPool((3, 3); stride = 2, pad = 1)
109102
end
110-
return inplanes, Chain(conv1, bn1, stempool)
103+
return Chain(conv1, bn1, stempool), inplanes
111104
end
112105

113106
function downsample_block(downsample_fn, inplanes, planes, expansion; kernel_size = (1, 1),
@@ -128,9 +121,8 @@ end
128121
# See `basicblock` and `bottleneck` for examples. A block must define a function
129122
# `expansion(::typeof(block))` that returns the expansion factor of the block.
130123
function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride = 32,
131-
downsample_fn = downsample_conv, downsample_args::NamedTuple = (),
132-
drop_block_rate = 0.0, drop_path_rate = 0.0,
133-
block_args::NamedTuple = ())
124+
downsample_fn = downsample_conv,
125+
drop_rates::NamedTuple, block_args::NamedTuple)
134126
@assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)"
135127
expansion = expansion_factor(block_fn)
136128
stages = []
@@ -139,7 +131,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
139131
dilation = prev_dilation = 1
140132
for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels,
141133
block_repeats,
142-
_drop_blocks(drop_block_rate)))
134+
_drop_blocks(drop_rates.drop_block_rate)))
143135
# Stride calculations for each stage
144136
stride = stage_idx == 1 ? 1 : 2
145137
if net_stride >= output_stride
@@ -148,16 +140,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
148140
else
149141
net_stride *= stride
150142
end
151-
# Downsample block; either a (default) convolution-based block or a pooling-based block.
143+
# Downsample block; either a (default) convolution-based block or a pooling-based block
152144
downsample = downsample_block(downsample_fn, inplanes, planes, expansion;
153-
stride, dilation, first_dilation = dilation, downsample_args...)
145+
stride, dilation, first_dilation = dilation)
154146
# Construct the blocks for each stage
155147
blocks = []
156148
for block_idx in 1:num_blocks
157149
downsample = block_idx == 1 ? downsample : identity
158150
stride = block_idx == 1 ? stride : 1
159151
# stochastic depth linear decay rule
160-
block_dpr = drop_path_rate * net_block_idx / (sum(block_repeats) - 1)
152+
block_dpr = drop_rates.drop_path_rate * net_block_idx / (sum(block_repeats) - 1)
161153
push!(blocks,
162154
block_fn(inplanes, planes; stride, downsample,
163155
first_dilation = prev_dilation,
@@ -171,21 +163,26 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
171163
return Chain(stages...)
172164
end
173165

166+
# returns `DropBlock`s for each block of the ResNet
167+
function _drop_blocks(drop_block_prob = 0.0)
168+
return [
169+
identity,
170+
identity,
171+
DropBlock(drop_block_prob, 5, 0.25),
172+
DropBlock(drop_block_prob, 3, 1.00),
173+
]
174+
end
175+
174176
function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
175-
stem_fn = resnet_stem, stem_args::NamedTuple = NamedTuple(),
176-
downsample_fn = downsample_conv, downsample_args::NamedTuple = NamedTuple(),
177+
stem = first(resnet_stem(; inchannels)), inplanes = 64,
178+
downsample_fn = downsample_conv,
177179
drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0,
178-
drop_block_rate = 0.5),
180+
drop_block_rate = 0.0),
179181
block_args::NamedTuple = NamedTuple())
180-
# Stem
181-
inplanes, stem = stem_fn(; inchannels, stem_args...)
182182
# Feature Blocks
183183
channels = [64, 128, 256, 512]
184184
stage_blocks = _make_blocks(block, channels, layers, inplanes;
185-
output_stride, downsample_fn, downsample_args,
186-
drop_block_rate = drop_rates.drop_block_rate,
187-
drop_path_rate = drop_rates.drop_path_rate,
188-
block_args)
185+
output_stride, downsample_fn, drop_rates, block_args)
189186
# Head (Pooling and Classifier)
190187
expansion = expansion_factor(block)
191188
num_features = 512 * expansion

0 commit comments

Comments (0)