@@ -5,19 +5,20 @@ Creates a single block of ConvNeXt.
5
5
([reference](https://arxiv.org/abs/2201.03545))
6
6
7
7
# Arguments:
8
- - `planes`: number of input channels.
9
- - `drop_path_rate`: Stochastic depth rate.
10
- - `λ`: Init value for LayerScale
8
+
9
+ - `planes`: number of input channels.
10
+ - `drop_path_rate`: Stochastic depth rate.
11
+ - `λ`: Init value for LayerScale
11
12
"""
12
- function convnextblock (planes, drop_path_rate = 0. , λ = 1f -6 )
13
- layers = SkipConnection (Chain (DepthwiseConv ((7 , 7 ), planes => planes; pad = 3 ),
14
- swapdims ((3 , 1 , 2 , 4 )),
15
- LayerNorm (planes; ϵ = 1f -6 ),
16
- mlp_block (planes, 4 * planes),
17
- LayerScale (planes, λ),
18
- swapdims ((2 , 3 , 1 , 4 )),
19
- DropPath (drop_path_rate)), + )
20
- return layers
13
+ function convnextblock (planes, drop_path_rate = 0.0 , λ = 1.0f -6 )
14
+ layers = SkipConnection (Chain (DepthwiseConv ((7 , 7 ), planes => planes; pad = 3 ),
15
+ swapdims ((3 , 1 , 2 , 4 )),
16
+ LayerNorm (planes; ϵ = 1.0f -6 ),
17
+ mlp_block (planes, 4 * planes),
18
+ LayerScale (planes, λ),
19
+ swapdims ((2 , 3 , 1 , 4 )),
20
+ DropPath (drop_path_rate)), + )
21
+ return layers
21
22
end
22
23
23
24
"""
@@ -27,52 +28,59 @@ Creates the layers for a ConvNeXt model.
27
28
([reference](https://arxiv.org/abs/2201.03545))
28
29
29
30
# Arguments:
30
- - `inchannels`: number of input channels.
31
- - `depths`: list with configuration for depth of each block
32
- - `planes`: list with configuration for number of output channels in each block
33
- - `drop_path_rate`: Stochastic depth rate.
34
- - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
35
- - `nclasses`: number of output classes
31
+
32
+ - `inchannels`: number of input channels.
33
+ - `depths`: list with configuration for depth of each block
34
+ - `planes`: list with configuration for number of output channels in each block
35
+ - `drop_path_rate`: Stochastic depth rate.
36
+ - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
37
+ - `nclasses`: number of output classes
36
38
"""
37
- function convnext (depths, planes; inchannels = 3 , drop_path_rate = 0. , λ = 1f-6 , nclasses = 1000 )
38
- @assert length (depths) == length (planes) " `planes` should have exactly one value for each block"
39
-
40
- downsample_layers = []
41
- stem = Chain (Conv ((4 , 4 ), inchannels => planes[1 ]; stride = 4 ),
42
- ChannelLayerNorm (planes[1 ]; ϵ = 1f-6 ))
43
- push! (downsample_layers, stem)
44
- for m in 1 : length (depths) - 1
45
- downsample_layer = Chain (ChannelLayerNorm (planes[m]; ϵ = 1f-6 ),
46
- Conv ((2 , 2 ), planes[m] => planes[m + 1 ]; stride = 2 ))
47
- push! (downsample_layers, downsample_layer)
48
- end
49
-
50
- stages = []
51
- dp_rates = LinRange {Float32} (0. , drop_path_rate, sum (depths))
52
- cur = 0
53
- for i in 1 : length (depths)
54
- push! (stages, [convnextblock (planes[i], dp_rates[cur + j], λ) for j in 1 : depths[i]])
55
- cur += depths[i]
56
- end
57
-
58
- backbone = collect (Iterators. flatten (Iterators. flatten (zip (downsample_layers, stages))))
59
- head = Chain (GlobalMeanPool (),
60
- MLUtils. flatten,
61
- LayerNorm (planes[end ]),
62
- Dense (planes[end ], nclasses))
63
-
64
- return Chain (Chain (backbone), head)
39
+ function convnext (depths, planes; inchannels = 3 , drop_path_rate = 0.0 , λ = 1.0f-6 ,
40
+ nclasses = 1000 )
41
+ @assert length (depths)== length (planes) " `planes` should have exactly one value for each block"
42
+
43
+ downsample_layers = []
44
+ stem = Chain (Conv ((4 , 4 ), inchannels => planes[1 ]; stride = 4 ),
45
+ ChannelLayerNorm (planes[1 ]; ϵ = 1.0f-6 ))
46
+ push! (downsample_layers, stem)
47
+ for m in 1 : (length (depths) - 1 )
48
+ downsample_layer = Chain (ChannelLayerNorm (planes[m]; ϵ = 1.0f-6 ),
49
+ Conv ((2 , 2 ), planes[m] => planes[m + 1 ]; stride = 2 ))
50
+ push! (downsample_layers, downsample_layer)
51
+ end
52
+
53
+ stages = []
54
+ dp_rates = LinRange {Float32} (0.0 , drop_path_rate, sum (depths))
55
+ cur = 0
56
+ for i in 1 : length (depths)
57
+ push! (stages, [convnextblock (planes[i], dp_rates[cur + j], λ) for j in 1 : depths[i]])
58
+ cur += depths[i]
59
+ end
60
+
61
+ backbone = collect (Iterators. flatten (Iterators. flatten (zip (downsample_layers, stages))))
62
+ head = Chain (GlobalMeanPool (),
63
+ MLUtils. flatten,
64
+ LayerNorm (planes[end ]),
65
+ Dense (planes[end ], nclasses))
66
+
67
+ return Chain (Chain (backbone), head)
65
68
end
66
69
67
70
# Configurations for ConvNeXt models
68
- convnext_configs = Dict (:tiny => Dict (:depths => [3 , 3 , 9 , 3 ], :planes => [96 , 192 , 384 , 768 ]),
69
- :small => Dict (:depths => [3 , 3 , 27 , 3 ], :planes => [96 , 192 , 384 , 768 ]),
70
- :base => Dict (:depths => [3 , 3 , 27 , 3 ], :planes => [128 , 256 , 512 , 1024 ]),
71
- :large => Dict (:depths => [3 , 3 , 27 , 3 ], :planes => [192 , 384 , 768 , 1536 ]),
72
- :xlarge => Dict (:depths => [3 , 3 , 27 , 3 ], :planes => [256 , 512 , 1024 , 2048 ]))
71
+ convnext_configs = Dict (:tiny => Dict (:depths => [3 , 3 , 9 , 3 ],
72
+ :planes => [96 , 192 , 384 , 768 ]),
73
+ :small => Dict (:depths => [3 , 3 , 27 , 3 ],
74
+ :planes => [96 , 192 , 384 , 768 ]),
75
+ :base => Dict (:depths => [3 , 3 , 27 , 3 ],
76
+ :planes => [128 , 256 , 512 , 1024 ]),
77
+ :large => Dict (:depths => [3 , 3 , 27 , 3 ],
78
+ :planes => [192 , 384 , 768 , 1536 ]),
79
+ :xlarge => Dict (:depths => [3 , 3 , 27 , 3 ],
80
+ :planes => [256 , 512 , 1024 , 2048 ]))
73
81
74
82
struct ConvNeXt
75
- layers
83
+ layers:: Any
76
84
end
77
85
78
86
"""
@@ -82,20 +90,21 @@ Creates a ConvNeXt model.
82
90
([reference](https://arxiv.org/abs/2201.03545))
83
91
84
92
# Arguments:
85
- - `inchannels`: number of input channels.
86
- - `drop_path_rate`: Stochastic depth rate.
87
- - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
88
- - `nclasses`: number of output classes
93
+
94
+ - `inchannels`: number of input channels.
95
+ - `drop_path_rate`: Stochastic depth rate.
96
+ - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
97
+ - `nclasses`: number of output classes
89
98
90
99
See also [`Metalhead.convnext`](#).
91
100
"""
92
- function ConvNeXt (mode:: Symbol = :base ; inchannels = 3 , drop_path_rate = 0. , λ = 1f -6 ,
101
+ function ConvNeXt (mode:: Symbol = :base ; inchannels = 3 , drop_path_rate = 0.0 , λ = 1.0f -6 ,
93
102
nclasses = 1000 )
94
- @assert mode in keys (convnext_configs) " `size` must be one of $(collect (keys (convnext_configs))) "
95
- depths = convnext_configs[mode][:depths ]
96
- planes = convnext_configs[mode][:planes ]
97
- layers = convnext (depths, planes; inchannels, drop_path_rate, λ, nclasses)
98
- return ConvNeXt (layers)
103
+ @assert mode in keys (convnext_configs) " `size` must be one of $(collect (keys (convnext_configs))) "
104
+ depths = convnext_configs[mode][:depths ]
105
+ planes = convnext_configs[mode][:planes ]
106
+ layers = convnext (depths, planes; inchannels, drop_path_rate, λ, nclasses)
107
+ return ConvNeXt (layers)
99
108
end
100
109
101
110
(m:: ConvNeXt )(x) = m. layers (x)
0 commit comments