1
- function dwsepconv_builder (block_configs, inplanes:: Integer , stage_idx:: Integer ,
2
- width_mult:: Real ; norm_layer = BatchNorm, kwargs... )
1
+ # TODO - potentially make these builders more flexible to specify stuff like
2
+ # activation functions and reductions that don't change over the stages
3
+
4
+ function dwsepconv_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ,
5
+ stage_idx:: Integer , scalings:: NTuple{2, Real} ;
6
+ norm_layer = BatchNorm, divisor:: Integer = 8 , kwargs... )
7
+ width_mult, depth_mult = scalings
3
8
block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
4
- outplanes = _round_channels (outplanes * width_mult)
9
+ outplanes = _round_channels (outplanes * width_mult, divisor )
5
10
if stage_idx != 1
6
- inplanes = _round_channels (block_configs[stage_idx - 1 ][3 ] * width_mult)
11
+ inplanes = _round_channels (block_configs[stage_idx - 1 ][3 ] * width_mult, divisor )
7
12
end
8
13
function get_layers (block_idx:: Integer )
9
14
inplanes = block_idx == 1 ? inplanes : outplanes
@@ -12,13 +17,14 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
12
17
stride, pad = SamePad (), norm_layer, kwargs... )... )
13
18
return (block,)
14
19
end
15
- return get_layers, nrepeats
20
+ return get_layers, ceil (Int, nrepeats * depth_mult)
16
21
end
22
+ _get_builder (:: typeof (dwsep_conv_norm)) = dwsepconv_builder
17
23
18
- function mbconv_builder (block_configs, inplanes :: Integer , stage_idx :: Integer ,
19
- scalings:: NTuple{2, Real} ; norm_layer = BatchNorm,
20
- divisor :: Integer = 8 , se_from_explanes :: Bool = false ,
21
- kwargs... )
24
+ function mbconv_builder (block_configs:: AbstractVector{<:Tuple} , inplanes :: Integer ,
25
+ stage_idx :: Integer , scalings:: NTuple{2, Real} ;
26
+ norm_layer = BatchNorm, divisor :: Integer = 8 ,
27
+ se_from_explanes :: Bool = false , kwargs... )
22
28
width_mult, depth_mult = scalings
23
29
block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
24
30
# calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
@@ -39,69 +45,31 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
39
45
end
40
46
return get_layers, ceil (Int, nrepeats * depth_mult)
41
47
end
48
+ _get_builder (:: typeof (mbconv)) = mbconv_builder
42
49
43
- function mbconv_builder (block_configs, inplanes:: Integer , stage_idx:: Integer ,
44
- width_mult:: Real ; norm_layer = BatchNorm, kwargs... )
45
- return mbconv_builder (block_configs, inplanes, stage_idx, (width_mult, 1 );
46
- norm_layer, kwargs... )
47
- end
48
-
49
- function fused_mbconv_builder (block_configs, inplanes:: Integer , stage_idx:: Integer ;
50
- norm_layer = BatchNorm, kwargs... )
50
+ function fused_mbconv_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ,
51
+ stage_idx:: Integer , scalings:: NTuple{2, Real} ;
52
+ norm_layer = BatchNorm, divisor:: Integer = 8 , kwargs... )
53
+ width_mult, depth_mult = scalings
51
54
block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
52
55
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1 ][3 ]
56
+ outplanes = _round_channels (outplanes * width_mult, divisor)
53
57
function get_layers (block_idx:: Integer )
54
58
inplanes = block_idx == 1 ? inplanes : outplanes
55
- explanes = _round_channels (inplanes * expansion, 8 )
59
+ explanes = _round_channels (inplanes * expansion, divisor )
56
60
stride = block_idx == 1 ? stride : 1
57
61
block = block_fn ((k, k), inplanes, explanes, outplanes, activation;
58
62
norm_layer, stride, kwargs... )
59
63
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
60
64
end
61
- return get_layers, nrepeats
62
- end
63
-
64
- # TODO - these builders need to be more flexible to potentially specify stuff like
65
- # activation functions and reductions that don't change
66
- function _get_builder (:: typeof (dwsep_conv_bn), block_configs:: AbstractVector{<:Tuple} ,
67
- inplanes:: Integer , stage_idx:: Integer ;
68
- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
69
- width_mult:: Union{Nothing, Number} = nothing , norm_layer, kwargs... )
70
- @assert isnothing (scalings) " dwsep_conv_bn does not support the `scalings` argument"
71
- return dwsepconv_builder (block_configs, inplanes, stage_idx, width_mult; norm_layer,
72
- kwargs... )
73
- end
74
-
75
- function _get_builder (:: typeof (mbconv), block_configs:: AbstractVector{<:Tuple} ,
76
- inplanes:: Integer , stage_idx:: Integer ;
77
- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
78
- width_mult:: Union{Nothing, Number} = nothing , norm_layer, kwargs... )
79
- if isnothing (scalings)
80
- return mbconv_builder (block_configs, inplanes, stage_idx, width_mult; norm_layer,
81
- kwargs... )
82
- elseif isnothing (width_mult)
83
- return mbconv_builder (block_configs, inplanes, stage_idx, scalings; norm_layer,
84
- kwargs... )
85
- else
86
- throw (ArgumentError (" Only one of `scalings` and `width_mult` can be specified" ))
87
- end
88
- end
89
-
90
- function _get_builder (:: typeof (fused_mbconv), block_configs:: AbstractVector{<:Tuple} ,
91
- inplanes:: Integer , stage_idx:: Integer ;
92
- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
93
- width_mult:: Union{Nothing, Number} = nothing , norm_layer)
94
- @assert isnothing (width_mult) " fused_mbconv does not support the `width_mult` argument."
95
- @assert isnothing (scalings)|| scalings == (1 , 1 ) " fused_mbconv does not support the `scalings` argument"
96
- return fused_mbconv_builder (block_configs, inplanes, stage_idx; norm_layer)
65
+ return get_layers, ceil (Int, nrepeats * depth_mult)
97
66
end
67
+ _get_builder (:: typeof (fused_mbconv)) = fused_mbconv_builder
98
68
99
- function mbconv_stack_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ;
100
- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
101
- width_mult:: Union{Nothing, Number} = nothing ,
102
- norm_layer = BatchNorm, kwargs... )
103
- bxs = [_get_builder (block_configs[idx][1 ], block_configs, inplanes, idx; scalings,
104
- width_mult, norm_layer, kwargs... )
69
+ function mbconv_stage_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ,
70
+ scalings:: NTuple{2, Real} ; kwargs... )
71
+ builders = _get_builder .(first .(block_configs))
72
+ bxs = [builders[idx](block_configs, inplanes, idx, scalings; kwargs... )
105
73
for idx in eachindex (block_configs)]
106
74
return (stage_idx, block_idx) -> first .(bxs)[stage_idx](block_idx), last .(bxs)
107
75
end
0 commit comments