Skip to content

Commit 7846f8b

Browse files
committed
More declarative interface for ResNet
1. Less keywords for the user to worry about 2. Delete `ResNeXt` just for now
1 parent 4fa28d4 commit 7846f8b

File tree

3 files changed

+134
-294
lines changed

3 files changed

+134
-294
lines changed

src/convnets/efficientnet.jl

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
66
77
# Arguments
88
9-
- `scalings`: global width and depth scaling (given as a tuple)
10-
- `block_config`: configuration for each inverted residual block,
11-
given as a vector of tuples with elements:
12-
- `n`: number of block repetitions (will be scaled by global depth scaling)
13-
- `k`: kernel size
14-
- `s`: kernel stride
15-
- `e`: expansion ratio
16-
- `i`: block input channels (will be scaled by global width scaling)
17-
- `o`: block output channels (will be scaled by global width scaling)
18-
- `inchannels`: number of input channels
19-
- `nclasses`: number of output classes
20-
- `max_width`: maximum number of output channels before the fully connected
21-
classification blocks
9+
- `scalings`: global width and depth scaling (given as a tuple)
10+
11+
- `block_config`: configuration for each inverted residual block,
12+
given as a vector of tuples with elements:
13+
14+
+ `n`: number of block repetitions (will be scaled by global depth scaling)
15+
+ `k`: kernel size
16+
+ `s`: kernel stride
17+
+ `e`: expansion ratio
18+
+ `i`: block input channels (will be scaled by global width scaling)
19+
+ `o`: block output channels (will be scaled by global width scaling)
20+
- `inchannels`: number of input channels
21+
- `nclasses`: number of output classes
22+
- `max_width`: maximum number of output channels before the fully connected
23+
classification blocks
2224
"""
2325
function efficientnet(scalings, block_config;
2426
inchannels = 3, nclasses = 1000, max_width = 1280)
@@ -64,34 +66,33 @@ end
6466
# i: block input channels
6567
# o: block output channels
6668
const efficientnet_block_configs = [
67-
# (n, k, s, e, i, o)
68-
(1, 3, 1, 1, 32, 16),
69-
(2, 3, 2, 6, 16, 24),
70-
(2, 5, 2, 6, 24, 40),
71-
(3, 3, 2, 6, 40, 80),
72-
(3, 5, 1, 6, 80, 112),
69+
# (n, k, s, e, i, o)
70+
(1, 3, 1, 1, 32, 16),
71+
(2, 3, 2, 6, 16, 24),
72+
(2, 5, 2, 6, 24, 40),
73+
(3, 3, 2, 6, 40, 80),
74+
(3, 5, 1, 6, 80, 112),
7375
(4, 5, 2, 6, 112, 192),
74-
(1, 3, 1, 6, 192, 320)
76+
(1, 3, 1, 6, 192, 320),
7577
]
7678

7779
# w: width scaling
7880
# d: depth scaling
7981
# r: image resolution
8082
const efficientnet_global_configs = Dict(
81-
# ( r, ( w, d))
82-
:b0 => (224, (1.0, 1.0)),
83-
:b1 => (240, (1.0, 1.1)),
84-
:b2 => (260, (1.1, 1.2)),
85-
:b3 => (300, (1.2, 1.4)),
86-
:b4 => (380, (1.4, 1.8)),
87-
:b5 => (456, (1.6, 2.2)),
88-
:b6 => (528, (1.8, 2.6)),
89-
:b7 => (600, (2.0, 3.1)),
90-
:b8 => (672, (2.2, 3.6))
91-
)
83+
# (r, (w, d))
84+
:b0 => (224, (1.0, 1.0)),
85+
:b1 => (240, (1.0, 1.1)),
86+
:b2 => (260, (1.1, 1.2)),
87+
:b3 => (300, (1.2, 1.4)),
88+
:b4 => (380, (1.4, 1.8)),
89+
:b5 => (456, (1.6, 2.2)),
90+
:b6 => (528, (1.8, 2.6)),
91+
:b7 => (600, (2.0, 3.1)),
92+
:b8 => (672, (2.2, 3.6)))
9293

9394
struct EfficientNet
94-
layers::Any
95+
layers::Any
9596
end
9697

9798
"""
@@ -103,27 +104,29 @@ See also [`efficientnet`](#).
103104
104105
# Arguments
105106
106-
- `scalings`: global width and depth scaling (given as a tuple)
107-
- `block_config`: configuration for each inverted residual block,
108-
given as a vector of tuples with elements:
109-
- `n`: number of block repetitions (will be scaled by global depth scaling)
110-
- `k`: kernel size
111-
- `s`: kernel stride
112-
- `e`: expansion ratio
113-
- `i`: block input channels (will be scaled by global width scaling)
114-
- `o`: block output channels (will be scaled by global width scaling)
115-
- `inchannels`: number of input channels
116-
- `nclasses`: number of output classes
117-
- `max_width`: maximum number of output channels before the fully connected
118-
classification blocks
107+
- `scalings`: global width and depth scaling (given as a tuple)
108+
109+
- `block_config`: configuration for each inverted residual block,
110+
given as a vector of tuples with elements:
111+
112+
+ `n`: number of block repetitions (will be scaled by global depth scaling)
113+
+ `k`: kernel size
114+
+ `s`: kernel stride
115+
+ `e`: expansion ratio
116+
+ `i`: block input channels (will be scaled by global width scaling)
117+
+ `o`: block output channels (will be scaled by global width scaling)
118+
- `inchannels`: number of input channels
119+
- `nclasses`: number of output classes
120+
- `max_width`: maximum number of output channels before the fully connected
121+
classification blocks
119122
"""
120123
function EfficientNet(scalings, block_config;
121124
inchannels = 3, nclasses = 1000, max_width = 1280)
122-
layers = efficientnet(scalings, block_config;
123-
inchannels = inchannels,
124-
nclasses = nclasses,
125-
max_width = max_width)
126-
return EfficientNet(layers)
125+
layers = efficientnet(scalings, block_config;
126+
inchannels = inchannels,
127+
nclasses = nclasses,
128+
max_width = max_width)
129+
return EfficientNet(layers)
127130
end
128131

129132
@functor EfficientNet
@@ -141,13 +144,13 @@ See also [`efficientnet`](#).
141144
142145
# Arguments
143146
144-
- `name`: name of default configuration
145-
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
146-
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
147+
- `name`: name of default configuration
148+
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
149+
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
147150
"""
148151
function EfficientNet(name::Symbol; pretrain = false)
149152
@assert name in keys(efficientnet_global_configs)
150-
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
153+
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
151154

152155
model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
153156
pretrain && loadpretrain!(model, string("efficientnet-", name))

src/convnets/mobilenet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
2828
function mobilenetv1(width_mult, config;
2929
activation = relu,
3030
inchannels = 3,
31-
fcsize = 1024,
31+
fcsize = 1024,
3232
nclasses = 1000)
3333
layers = []
3434
for (dw, outch, stride, nrepeats) in config

0 commit comments

Comments
 (0)