1
1
function mixed_5b ()
2
- branch1 = Chain (conv_norm ((1 , 1 ), 192 , 96 )... )
3
- branch2 = Chain (conv_norm ((1 , 1 ), 192 , 48 )... ,
4
- conv_norm ((5 , 5 ), 48 , 64 ; pad = 2 )... )
5
- branch3 = Chain (conv_norm ((1 , 1 ), 192 , 64 )... ,
6
- conv_norm ((3 , 3 ), 64 , 96 ; pad = 1 )... ,
7
- conv_norm ((3 , 3 ), 96 , 96 ; pad = 1 )... )
2
+ branch1 = Chain (basic_conv_bn ((1 , 1 ), 192 , 96 )... )
3
+ branch2 = Chain (basic_conv_bn ((1 , 1 ), 192 , 48 )... ,
4
+ basic_conv_bn ((5 , 5 ), 48 , 64 ; pad = 2 )... )
5
+ branch3 = Chain (basic_conv_bn ((1 , 1 ), 192 , 64 )... ,
6
+ basic_conv_bn ((3 , 3 ), 64 , 96 ; pad = 1 )... ,
7
+ basic_conv_bn ((3 , 3 ), 96 , 96 ; pad = 1 )... )
8
8
branch4 = Chain (MeanPool ((3 , 3 ); pad = 1 , stride = 1 ),
9
- conv_norm ((1 , 1 ), 192 , 64 )... )
9
+ basic_conv_bn ((1 , 1 ), 192 , 64 )... )
10
10
return Parallel (cat_channels, branch1, branch2, branch3, branch4)
11
11
end
12
12
13
13
function block35 (scale = 1.0f0 )
14
- branch1 = Chain (conv_norm ((1 , 1 ), 320 , 32 )... )
15
- branch2 = Chain (conv_norm ((1 , 1 ), 320 , 32 )... ,
16
- conv_norm ((3 , 3 ), 32 , 32 ; pad = 1 )... )
17
- branch3 = Chain (conv_norm ((1 , 1 ), 320 , 32 )... ,
18
- conv_norm ((3 , 3 ), 32 , 48 ; pad = 1 )... ,
19
- conv_norm ((3 , 3 ), 48 , 64 ; pad = 1 )... )
20
- branch4 = Chain (conv_norm ((1 , 1 ), 128 , 320 )... )
14
+ branch1 = Chain (basic_conv_bn ((1 , 1 ), 320 , 32 )... )
15
+ branch2 = Chain (basic_conv_bn ((1 , 1 ), 320 , 32 )... ,
16
+ basic_conv_bn ((3 , 3 ), 32 , 32 ; pad = 1 )... )
17
+ branch3 = Chain (basic_conv_bn ((1 , 1 ), 320 , 32 )... ,
18
+ basic_conv_bn ((3 , 3 ), 32 , 48 ; pad = 1 )... ,
19
+ basic_conv_bn ((3 , 3 ), 48 , 64 ; pad = 1 )... )
20
+ branch4 = Chain (basic_conv_bn ((1 , 1 ), 128 , 320 )... )
21
21
return SkipConnection (Chain (Parallel (cat_channels, branch1, branch2, branch3),
22
22
branch4, inputscale (scale; activation = relu)), + )
23
23
end
24
24
25
25
function mixed_6a ()
26
- branch1 = Chain (conv_norm ((3 , 3 ), 320 , 384 ; stride = 2 )... )
27
- branch2 = Chain (conv_norm ((1 , 1 ), 320 , 256 )... ,
28
- conv_norm ((3 , 3 ), 256 , 256 ; pad = 1 )... ,
29
- conv_norm ((3 , 3 ), 256 , 384 ; stride = 2 )... )
26
+ branch1 = Chain (basic_conv_bn ((3 , 3 ), 320 , 384 ; stride = 2 )... )
27
+ branch2 = Chain (basic_conv_bn ((1 , 1 ), 320 , 256 )... ,
28
+ basic_conv_bn ((3 , 3 ), 256 , 256 ; pad = 1 )... ,
29
+ basic_conv_bn ((3 , 3 ), 256 , 384 ; stride = 2 )... )
30
30
branch3 = MaxPool ((3 , 3 ); stride = 2 )
31
31
return Parallel (cat_channels, branch1, branch2, branch3)
32
32
end
33
33
34
34
function block17 (scale = 1.0f0 )
35
- branch1 = Chain (conv_norm ((1 , 1 ), 1088 , 192 )... )
36
- branch2 = Chain (conv_norm ((1 , 1 ), 1088 , 128 )... ,
37
- conv_norm ((7 , 1 ), 128 , 160 ; pad = (3 , 0 ))... ,
38
- conv_norm ((1 , 7 ), 160 , 192 ; pad = (0 , 3 ))... )
39
- branch3 = Chain (conv_norm ((1 , 1 ), 384 , 1088 )... )
35
+ branch1 = Chain (basic_conv_bn ((1 , 1 ), 1088 , 192 )... )
36
+ branch2 = Chain (basic_conv_bn ((1 , 1 ), 1088 , 128 )... ,
37
+ basic_conv_bn ((7 , 1 ), 128 , 160 ; pad = (3 , 0 ))... ,
38
+ basic_conv_bn ((1 , 7 ), 160 , 192 ; pad = (0 , 3 ))... )
39
+ branch3 = Chain (basic_conv_bn ((1 , 1 ), 384 , 1088 )... )
40
40
return SkipConnection (Chain (Parallel (cat_channels, branch1, branch2),
41
41
branch3, inputscale (scale; activation = relu)), + )
42
42
end
43
43
44
44
function mixed_7a ()
45
- branch1 = Chain (conv_norm ((1 , 1 ), 1088 , 256 )... ,
46
- conv_norm ((3 , 3 ), 256 , 384 ; stride = 2 )... )
47
- branch2 = Chain (conv_norm ((1 , 1 ), 1088 , 256 )... ,
48
- conv_norm ((3 , 3 ), 256 , 288 ; stride = 2 )... )
49
- branch3 = Chain (conv_norm ((1 , 1 ), 1088 , 256 )... ,
50
- conv_norm ((3 , 3 ), 256 , 288 ; pad = 1 )... ,
51
- conv_norm ((3 , 3 ), 288 , 320 ; stride = 2 )... )
45
+ branch1 = Chain (basic_conv_bn ((1 , 1 ), 1088 , 256 )... ,
46
+ basic_conv_bn ((3 , 3 ), 256 , 384 ; stride = 2 )... )
47
+ branch2 = Chain (basic_conv_bn ((1 , 1 ), 1088 , 256 )... ,
48
+ basic_conv_bn ((3 , 3 ), 256 , 288 ; stride = 2 )... )
49
+ branch3 = Chain (basic_conv_bn ((1 , 1 ), 1088 , 256 )... ,
50
+ basic_conv_bn ((3 , 3 ), 256 , 288 ; pad = 1 )... ,
51
+ basic_conv_bn ((3 , 3 ), 288 , 320 ; stride = 2 )... )
52
52
branch4 = MaxPool ((3 , 3 ); stride = 2 )
53
53
return Parallel (cat_channels, branch1, branch2, branch3, branch4)
54
54
end
55
55
56
56
function block8 (scale = 1.0f0 ; activation = identity)
57
- branch1 = Chain (conv_norm ((1 , 1 ), 2080 , 192 )... )
58
- branch2 = Chain (conv_norm ((1 , 1 ), 2080 , 192 )... ,
59
- conv_norm ((3 , 1 ), 192 , 224 ; pad = (1 , 0 ))... ,
60
- conv_norm ((1 , 3 ), 224 , 256 ; pad = (0 , 1 ))... )
61
- branch3 = Chain (conv_norm ((1 , 1 ), 448 , 2080 )... )
57
+ branch1 = Chain (basic_conv_bn ((1 , 1 ), 2080 , 192 )... )
58
+ branch2 = Chain (basic_conv_bn ((1 , 1 ), 2080 , 192 )... ,
59
+ basic_conv_bn ((3 , 1 ), 192 , 224 ; pad = (1 , 0 ))... ,
60
+ basic_conv_bn ((1 , 3 ), 224 , 256 ; pad = (0 , 1 ))... )
61
+ branch3 = Chain (basic_conv_bn ((1 , 1 ), 448 , 2080 )... )
62
62
return SkipConnection (Chain (Parallel (cat_channels, branch1, branch2),
63
63
branch3, inputscale (scale; activation)), + )
64
64
end
@@ -77,12 +77,12 @@ Creates an InceptionResNetv2 model.
77
77
"""
78
78
function inceptionresnetv2 (; dropout_rate = 0.0 , inchannels:: Integer = 3 ,
79
79
nclasses:: Integer = 1000 )
80
- backbone = Chain (conv_norm ((3 , 3 ), inchannels, 32 ; stride = 2 )... ,
81
- conv_norm ((3 , 3 ), 32 , 32 )... ,
82
- conv_norm ((3 , 3 ), 32 , 64 ; pad = 1 )... ,
80
+ backbone = Chain (basic_conv_bn ((3 , 3 ), inchannels, 32 ; stride = 2 )... ,
81
+ basic_conv_bn ((3 , 3 ), 32 , 32 )... ,
82
+ basic_conv_bn ((3 , 3 ), 32 , 64 ; pad = 1 )... ,
83
83
MaxPool ((3 , 3 ); stride = 2 ),
84
- conv_norm ((3 , 3 ), 64 , 80 )... ,
85
- conv_norm ((3 , 3 ), 80 , 192 )... ,
84
+ basic_conv_bn ((3 , 3 ), 64 , 80 )... ,
85
+ basic_conv_bn ((3 , 3 ), 80 , 192 )... ,
86
86
MaxPool ((3 , 3 ); stride = 2 ),
87
87
mixed_5b (),
88
88
[block35 (0.17f0 ) for _ in 1 : 10 ]. .. ,
@@ -91,7 +91,7 @@ function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3,
91
91
mixed_7a (),
92
92
[block8 (0.20f0 ) for _ in 1 : 9 ]. .. ,
93
93
block8 (; activation = relu),
94
- conv_norm ((1 , 1 ), 2080 , 1536 )... )
94
+ basic_conv_bn ((1 , 1 ), 2080 , 1536 )... )
95
95
return Chain (backbone, create_classifier (1536 , nclasses; dropout_rate))
96
96
end
97
97
0 commit comments