Skip to content

Commit ccb54da

Browse files
Cleanup - docs and code
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
1 parent b143b95 commit ccb54da

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

src/convnets/resnets/core.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,19 @@
44
drop_block = identity, drop_path = identity,
55
attn_fn = planes -> identity)
66
7-
Creates a basic ResNet block.
7+
Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
88
99
# Arguments
1010
1111
- `inplanes`: number of input feature maps
1212
- `planes`: number of feature maps for the block
1313
- `stride`: the stride of the block
14-
- `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first
15-
convolution.
14+
- `reduction_factor`: the factor by which the input feature maps
15+
are reduced before the first convolution.
1616
- `activation`: the activation function to use.
1717
- `norm_layer`: the normalization layer to use.
18-
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
19-
function and passed in.
20-
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
21-
function and passed in.
18+
- `drop_block`: the drop block layer
19+
- `drop_path`: the drop path layer
2220
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
2321
"""
2422
function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
@@ -36,7 +34,6 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
3634
drop_path]
3735
return Chain(filter!(!=(identity), layers)...)
3836
end
39-
expansion_factor(::typeof(basicblock)) = 1
4037

4138
"""
4239
bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
@@ -45,7 +42,7 @@ expansion_factor(::typeof(basicblock)) = 1
4542
drop_block = identity, drop_path = identity,
4643
attn_fn = planes -> identity)
4744
48-
Creates a bottleneck ResNet block.
45+
Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
4946
5047
# Arguments
5148
@@ -58,10 +55,8 @@ Creates a bottleneck ResNet block.
5855
convolution.
5956
- `activation`: the activation function to use.
6057
- `norm_layer`: the normalization layer to use.
61-
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
62-
function and passed in.
63-
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
64-
function and passed in.
58+
- `drop_block`: the drop block layer
59+
- `drop_path`: the drop path layer
6560
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
6661
"""
6762
function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
@@ -83,7 +78,6 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
8378
attn_fn(outplanes), drop_path]
8479
return Chain(filter!(!=(identity), layers)...)
8580
end
86-
expansion_factor(::typeof(bottleneck)) = 4
8781

8882
# Downsample layer using convolutions.
8983
function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1,
@@ -124,7 +118,7 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
124118
:D => (downsample_pool, downsample_identity))
125119

126120
# Stride for each block in the ResNet model
127-
function get_stride(block_idx::Integer, stage_idx::Integer)
121+
function resnet_stride(stage_idx::Integer, block_idx::Integer)
128122
return (stage_idx == 1 || block_idx != 1) ? 1 : 2
129123
end
130124

@@ -159,8 +153,7 @@ on how to use this function.
159153
shows peformance improvements over the `:deep` stem in some cases.
160154
161155
- `inchannels`: The number of channels in the input.
162-
- `replace_pool`: Whether to replace the default 3x3 max pooling layer with a
163-
3x3 convolution with stride 2 and a normalisation layer.
156+
- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
164157
- `norm_layer`: The normalisation layer used in the stem.
165158
- `activation`: The activation function used in the stem.
166159
"""
@@ -270,7 +263,7 @@ end
270263
function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
271264
# Construct each stage
272265
stages = []
273-
for (stage_idx, (num_blocks)) in enumerate(block_repeats)
266+
for (stage_idx, num_blocks) in enumerate(block_repeats)
274267
# Construct the blocks for each stage
275268
blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...)
276269
for block_idx in range(1, num_blocks)]
@@ -307,6 +300,7 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
307300
stride_fn = get_stride, planes_fn = resnet_planes,
308301
downsample_tuple = downsample_opt)
309302
else
303+
# TODO: write better message when we have link to dev docs for resnet
310304
throw(ArgumentError("Unknown block type $block_type"))
311305
end
312306
classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
@@ -318,7 +312,7 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
318312
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
319313
end
320314

321-
function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer},
315+
function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection,
322316
classifier_fn)
323317
# Build stages of the ResNet
324318
stage_blocks = resnet_stages(get_layers, block_repeats, connection)

0 commit comments

Comments
 (0)