@@ -1,8 +1,8 @@
 """
     conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-            rev = false, preact = true,
-            stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
-            initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)
+            rev = false, preact = false, use_bn = true,
+            initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1,
+            kwargs...)
 
 Create a convolution + batch normalization pair with activation.
 
@@ -15,6 +15,8 @@ Create a convolution + batch normalization pair with activation.
 - `rev`: set to `true` to place the batch norm before the convolution
 - `preact`: set to `true` to place the activation function before the batch norm
             (only compatible with `rev = false`)
+- `use_bn`: set to `false` to disable batch normalization
+            (only compatible with `rev = false` and `preact = false`)
 - `stride`: stride of the convolution kernel
 - `pad`: padding of the convolution kernel
 - `dilation`: dilation of the convolution kernel
@@ -24,9 +26,13 @@ Create a convolution + batch normalization pair with activation.
 - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
 """
 function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-                 rev = false, preact = false,
+                 rev = false, preact = false, use_bn = true,
                  initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1.0f-5, momentum = 1.0f-1,
                  kwargs...)
+    if !use_bn
+        (preact || rev) ? throw("preact only supported with `use_bn = true`") :
+        return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)]
+    end
     layers = []
     if rev
         activations = (conv = activation, bn = identity)
@@ -49,18 +55,18 @@
 
 
 """
     depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-                          rev = false,
-                          stride = 1, pad = 0, dilation = 1, [bias, weight, init],
-                          initβ = Flux.zeros32, initγ = Flux.ones32,
-                          ϵ = 1f-5, momentum = 1f-1)
+                          rev = false, use_bn1 = true, use_bn2 = true,
+                          initβ = Flux.zeros32, initγ = Flux.ones32,
+                          ϵ = 1.0f-5, momentum = 1.0f-1,
+                          stride = 1, kwargs...)
 
-Create a depthwise separable convolution chain as used in MobileNet v1.
+Create a depthwise separable convolution chain as used in MobileNetv1.
 This is sequence of layers:
 
 - a `kernelsize` depthwise convolution from `inplanes => inplanes`
-- a batch norm layer + `activation`
+- a batch norm layer + `activation` (if `use_bn1`; otherwise `activation` is applied to the convolution output)
 - a `kernelsize` convolution from `inplanes => outplanes`
-- a batch norm layer + `activation`
+- a batch norm layer + `activation` (if `use_bn2`; otherwise `activation` is applied to the convolution output)
 
 See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 
@@ -71,6 +77,8 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 - `outplanes`: number of output feature maps
 - `activation`: the activation function for the final layer
 - `rev`: set to `true` to place the batch norm before the convolution
+- `use_bn1`: set to `true` to use a batch norm after the depthwise convolution
+- `use_bn2`: set to `true` to use a batch norm after the pointwise convolution
 - `stride`: stride of the first convolution kernel
 - `pad`: padding of the first convolution kernel
 - `dilation`: dilation of the first convolution kernel
@@ -79,16 +87,16 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 - `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
 """
 function depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
-                               rev = false,
+                               rev = false, use_bn1 = true, use_bn2 = true,
                                initβ = Flux.zeros32, initγ = Flux.ones32,
                                ϵ = 1.0f-5, momentum = 1.0f-1,
                                stride = 1, kwargs...)
     return vcat(conv_bn(kernelsize, inplanes, inplanes, activation;
                         rev = rev, initβ = initβ, initγ = initγ,
-                        ϵ = ϵ, momentum = momentum,
+                        ϵ = ϵ, momentum = momentum, use_bn = use_bn1,
                         stride = stride, groups = Int(inplanes), kwargs...),
                 conv_bn((1, 1), inplanes, outplanes, activation;
-                        rev = rev, initβ = initβ, initγ = initγ,
+                        rev = rev, initβ = initβ, initγ = initγ, use_bn = use_bn2,
                         ϵ = ϵ, momentum = momentum))
 end
 
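
To make the effect of the new flags concrete, here is a minimal sketch rather than the patched code itself. `conv_bn_sketch` and `depthwise_sep_sketch` are hypothetical stand-ins: they model only the default ordering (`rev = false`, `preact = false`), leave out the batch norm hyperparameters, and assume a Flux version whose `Conv` accepts the `groups` keyword.

```julia
using Flux

# Hypothetical, simplified stand-in for the patched `conv_bn`.
# Only the default ordering (conv -> batch norm -> activation) is modelled.
function conv_bn_sketch(kernelsize, inplanes, outplanes, activation = relu;
                        use_bn = true, kwargs...)
    if !use_bn
        # No batch norm: the activation is attached to the convolution itself.
        return [Conv(kernelsize, inplanes => outplanes, activation; kwargs...)]
    end
    # With batch norm: plain convolution, then BatchNorm carrying the activation.
    return [Conv(kernelsize, inplanes => outplanes; kwargs...),
            BatchNorm(outplanes, activation)]
end

# Hypothetical stand-in for `depthwise_sep_conv_bn`: a depthwise convolution
# (groups = inplanes) followed by a 1x1 pointwise convolution.
function depthwise_sep_sketch(kernelsize, inplanes, outplanes, activation = relu;
                              use_bn1 = true, use_bn2 = true, kwargs...)
    return vcat(conv_bn_sketch(kernelsize, inplanes, inplanes, activation;
                               use_bn = use_bn1, groups = inplanes, kwargs...),
                conv_bn_sketch((1, 1), inplanes, outplanes, activation;
                               use_bn = use_bn2))
end

# Skip the batch norm after the depthwise convolution only.
layers = depthwise_sep_sketch((3, 3), 16, 32; use_bn1 = false, pad = 1)

# Run a dummy 8x8 input with 16 channels through the resulting chain.
x = rand(Float32, 8, 8, 16, 1)
y = Chain(layers...)(x)
```

In this sketch, `use_bn1 = false` collapses the depthwise stage to a single `Conv` that carries the activation, so the chain holds three layers instead of four.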