@@ -47,12 +47,15 @@ def __init__(
      w_init: Optional[Any],
      name: Optional[str] = None,
  ):
+
    super().__init__(name=name)
+
    self.use_projection = use_projection
    self.use_batch_norm = use_batch_norm
    self.shortcut_weight = shortcut_weight

    if self.use_projection and self.shortcut_weight != 0.0:
+
      self.proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
@@ -61,11 +64,13 @@ def __init__(
          with_bias=not use_batch_norm,
          padding="SAME",
          name="shortcut_conv")
+
      if use_batch_norm:
        self.proj_batchnorm = hk.BatchNorm(
            name="shortcut_batchnorm", **BN_CONFIG)

    channel_div = 4 if bottleneck else 1
+
    conv_0 = hk.Conv2D(
        output_channels=channels // channel_div,
        kernel_shape=1 if bottleneck else 3,
@@ -87,8 +92,10 @@ def __init__(
    layers = (conv_0, conv_1)

    if use_batch_norm:
+
      bn_0 = hk.BatchNorm(name="batchnorm_0", **BN_CONFIG)
      bn_1 = hk.BatchNorm(name="batchnorm_1", **BN_CONFIG)
+
      bn_layers = (bn_0, bn_1)

    if bottleneck:
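A note on the **BN_CONFIG pattern used here: the config dict itself is defined elsewhere in the file and is not part of this diff. As a rough sketch of what it presumably has to contain (hk.BatchNorm requires create_scale, create_offset and decay_rate; the values below are illustrative assumptions, not taken from the commit):

# Assumed sketch only -- BN_CONFIG is defined outside this diff.
BN_CONFIG = {
    "create_scale": True,    # learn a per-channel scale (gamma)
    "create_offset": True,   # learn a per-channel offset (beta)
    "decay_rate": 0.9,       # EMA decay for the moving batch statistics
}

# A call such as hk.BatchNorm(name="batchnorm_0", **BN_CONFIG) then expands to
# hk.BatchNorm(name="batchnorm_0", create_scale=True, create_offset=True, decay_rate=0.9).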
@@ -112,23 +119,31 @@ def __init__(
    self.activation = activation

  def __call__(self, inputs, is_training, test_local_stats):
+
    out = shortcut = inputs

    if self.use_projection and self.shortcut_weight != 0.0:
+
      shortcut = self.proj_conv(shortcut)
+
      if self.use_batch_norm:
        shortcut = self.proj_batchnorm(shortcut, is_training, test_local_stats)

    for i, conv_i in enumerate(self.layers):
+
      out = conv_i(out)
+
      if self.use_batch_norm:
        out = self.bn_layers[i](out, is_training, test_local_stats)
+
      if i < len(self.layers) - 1:  # Don't apply activation on last layer
        out = self.activation(out)

    if self.shortcut_weight is None:
      return self.activation(out + shortcut)
+
    elif self.shortcut_weight != 0.0:
+
      return self.activation(
          math.sqrt(1 - self.shortcut_weight**2) * out +
          self.shortcut_weight * shortcut)
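The merge at the end of __call__ is the distinctive part of this block: rather than the plain out + shortcut of a standard ResNet, the two branches are mixed with weights sqrt(1 - w^2) and w, which keeps the merged activations at roughly unit variance when both branches are roughly unit-variance and uncorrelated. A minimal standalone sketch of that merge rule (illustration only, not code from the commit):

import math

def merge_branches(out, shortcut, shortcut_weight, activation):
  # Sketch of the residual merge used in the block above.
  if shortcut_weight is None:
    return activation(out + shortcut)   # ordinary ResNet addition
  if shortcut_weight == 0.0:
    return activation(out)              # shortcut contribution switched off
  # Variance-preserving mix: sqrt(1 - w^2)**2 + w**2 == 1.
  return activation(
      math.sqrt(1 - shortcut_weight**2) * out
      + shortcut_weight * shortcut)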
@@ -151,12 +166,15 @@ def __init__(
      w_init: Optional[Any],
      name: Optional[str] = None,
  ):
+
    super().__init__(name=name)
+
    self.use_projection = use_projection
    self.use_batch_norm = use_batch_norm
    self.shortcut_weight = shortcut_weight

    if self.use_projection and self.shortcut_weight != 0.0:
+
      self.proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
@@ -167,6 +185,7 @@ def __init__(
          name="shortcut_conv")

    channel_div = 4 if bottleneck else 1
+
    conv_0 = hk.Conv2D(
        output_channels=channels // channel_div,
        kernel_shape=1 if bottleneck else 3,
@@ -188,11 +207,14 @@ def __init__(
    layers = (conv_0, conv_1)

    if use_batch_norm:
+
      bn_0 = hk.BatchNorm(name="batchnorm_0", **BN_CONFIG)
      bn_1 = hk.BatchNorm(name="batchnorm_1", **BN_CONFIG)
+
      bn_layers = (bn_0, bn_1)

    if bottleneck:
+
      conv_2 = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
@@ -205,8 +227,10 @@ def __init__(
      layers = layers + (conv_2,)

      if use_batch_norm:
+
        bn_2 = hk.BatchNorm(name="batchnorm_2", **BN_CONFIG)
        bn_layers += (bn_2,)
+
    self.bn_layers = bn_layers

    self.layers = layers
@@ -229,9 +253,11 @@ def __call__(self, inputs, is_training, test_local_stats):

    if self.shortcut_weight is None:
      return x + shortcut
+
    elif self.shortcut_weight != 0.0:
      return math.sqrt(
          1 - self.shortcut_weight**2) * x + self.shortcut_weight * shortcut
+
    else:
      return x

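BlockV2 uses the same sqrt(1 - w^2) / w weighting but, being a pre-activation block, returns the weighted sum directly with no activation after the merge. A quick numeric check of the variance-preserving property of that weighting (illustration only):

import numpy as np

rng = np.random.default_rng(0)
w = 0.7
x = rng.standard_normal(100_000)         # residual branch, ~unit variance
shortcut = rng.standard_normal(100_000)  # skip branch, ~unit variance
merged = np.sqrt(1 - w**2) * x + w * shortcut
print(np.var(merged))  # ~1.0, because (1 - w**2) + w**2 == 1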
@@ -272,13 +298,17 @@ def __init__(
          name="block_%d" % (i)))

  def __call__(self, inputs, is_training, test_local_stats):
+
    out = inputs
+
    for block in self.blocks:
      out = block(out, is_training, test_local_stats)
+
    return out


def check_length(length, value, name):
+
  if len(value) != length:
    raise ValueError(f"`{name}` must be of length 4 not {len(value)}")

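check_length is a small guard for the per-group configuration tuples; its call sites are outside this diff. A hypothetical usage, where blocks_per_group does appear later in the file and channels_per_group is an assumed name:

# Hypothetical call sites -- not part of this diff.
blocks_per_group = (3, 4, 6, 3)              # e.g. a ResNet-50-style layout
channels_per_group = (256, 512, 1024, 2048)

check_length(4, blocks_per_group, "blocks_per_group")
check_length(4, channels_per_group, "channels_per_group")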
@@ -481,12 +511,15 @@ def __init__(
    self.logits = hk.Linear(num_classes, **logits_config)

  def __call__(self, inputs, is_training, test_local_stats=False):
+
    out = inputs
    out = self.initial_conv(out)

    if not self.resnet_v2:
+
      if self.use_batch_norm:
        out = self.initial_batchnorm(out, is_training, test_local_stats)
+
      out = self.activation(out)

    out = hk.max_pool(
@@ -525,15 +558,18 @@ def subnet_max_func(x, r_fn, depth, shortcut_weight, resnet_v2=True):

  if bottleneck and resnet_v2:
    res_fn = lambda z: r_fn(r_fn(r_fn(z)))
+
  elif (not bottleneck and resnet_v2) or (bottleneck and not resnet_v2):
    res_fn = lambda z: r_fn(r_fn(z))
+
  else:
    res_fn = r_fn

  res_branch_subnetwork = res_fn(x)

  for i in range(4):
    for j in range(blocks_per_group[i]):
+
      res_x = res_fn(x)

      if j == 0 and use_projection[i] and resnet_v2:
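The branching at the top of subnet_max_func only decides how many times r_fn is composed for one residual branch (three, two, or one application, depending on the block variant). The middle condition is just an exclusive-or; an equivalent, more compact selection, shown purely as a reading aid rather than as code from the commit:

def make_res_fn(r_fn, bottleneck, resnet_v2):
  # Hypothetical helper: same case analysis as the if/elif/else above.
  if bottleneck and resnet_v2:
    return lambda z: r_fn(r_fn(r_fn(z)))   # composed three times
  elif bottleneck != resnet_v2:            # exactly one of the two flags set
    return lambda z: r_fn(r_fn(z))         # composed twice
  else:
    return r_fn                            # applied once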