
Commit 1a8f590

Update EfficientNet feature extraction for EfficientDet. Add needed MaxPoolSame as well.
1 parent: e01ccb8

9 files changed: +182 additions, -116 deletions


timm/models/efficientnet.py

Lines changed: 33 additions & 11 deletions
@@ -326,7 +326,6 @@ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3,
         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        print(stem_size)
         self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
@@ -393,7 +392,7 @@ class EfficientNetFeatures(nn.Module):
     and object detection models.
     """

-    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
+    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                  in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@@ -404,6 +403,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
         num_stages = max(out_indices) + 1

         self.out_indices = out_indices
+        self.feature_location = feature_location
         self.drop_rate = drop_rate
         self._in_chs = in_chans

@@ -420,34 +420,56 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
-        self.feature_info = builder.features  # builder provides info about feature channels for each block
+        self._feature_info = builder.features  # builder provides info about feature channels for each block
+        self._stage_to_feature_idx = {
+            v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
         self._in_chs = builder.in_chs

         efficientnet_init_weights(self)
         if _DEBUG:
-            for k, v in self.feature_info.items():
+            for k, v in self._feature_info.items():
                 print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))

         # Register feature extraction hooks with FeatureHooks helper
-        hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
-        hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
-        self.feature_hooks = FeatureHooks(hooks, self.named_modules())
+        self.feature_hooks = None
+        if feature_location != 'bottleneck':
+            hooks = [dict(
+                name=self._feature_info[idx]['module'],
+                type=self._feature_info[idx]['hook_type']) for idx in out_indices]
+            self.feature_hooks = FeatureHooks(hooks, self.named_modules())

     def feature_channels(self, idx=None):
         """ Feature Channel Shortcut
         Returns feature channel count for each output index if idx == None. If idx is an integer, will
         return feature channel count for that feature block index (independent of out_indices setting).
         """
         if isinstance(idx, int):
-            return self.feature_info[idx]['num_chs']
-        return [self.feature_info[i]['num_chs'] for i in self.out_indices]
+            return self._feature_info[idx]['num_chs']
+        return [self._feature_info[i]['num_chs'] for i in self.out_indices]
+
+    def feature_info(self, idx=None):
+        """ Feature Channel Shortcut
+        Returns feature channel count for each output index if idx == None. If idx is an integer, will
+        return feature channel count for that feature block index (independent of out_indices setting).
+        """
+        if isinstance(idx, int):
+            return self._feature_info[idx]
+        return [self._feature_info[i] for i in self.out_indices]

     def forward(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
         x = self.act1(x)
-        self.blocks(x)
-        return self.feature_hooks.get_output(x.device)
+        if self.feature_hooks is None:
+            features = []
+            for i, b in enumerate(self.blocks):
+                x = b(x)
+                if i in self._stage_to_feature_idx:
+                    features.append(x)
+            return features
+        else:
+            self.blocks(x)
+            return self.feature_hooks.get_output(x.device)


 def _create_model(model_kwargs, default_cfg, pretrained=False):
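With the new default feature_location='bottleneck', EfficientNetFeatures no longer registers hooks at all; forward() simply walks the block stages and collects the outputs whose stage index maps to a requested out_index. A rough usage sketch of that path follows (the arch definition strings and decode_arch_def call mirror how timm's EfficientNet configs are normally built, but this exact hand-rolled construction is an assumption, not code from this commit):

import torch
from timm.models.efficientnet import EfficientNetFeatures
from timm.models.efficientnet_builder import decode_arch_def

# EfficientNet-B0 style stage definition, hand-rolled here purely for illustration
arch_def = [
    ['ds_r1_k3_s1_e1_c16_se0.25'],
    ['ir_r2_k3_s2_e6_c24_se0.25'],
    ['ir_r2_k5_s2_e6_c40_se0.25'],
    ['ir_r3_k3_s2_e6_c80_se0.25'],
    ['ir_r3_k5_s1_e6_c112_se0.25'],
    ['ir_r4_k5_s2_e6_c192_se0.25'],
    ['ir_r1_k3_s1_e6_c320_se0.25'],
]
model = EfficientNetFeatures(decode_arch_def(arch_def), out_indices=(0, 1, 2, 3, 4))
model.eval()

features = model(torch.randn(1, 3, 224, 224))  # plain list of feature maps, no hooks involved
for f, num_chs in zip(features, model.feature_channels()):
    assert f.shape[1] == num_chs  # channel counts line up with the builder's feature info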

timm/models/efficientnet_blocks.py

Lines changed: 33 additions & 31 deletions
@@ -120,11 +120,13 @@ def __init__(self, in_chs, out_chs, kernel_size,
         self.bn1 = norm_layer(out_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

-    def feature_module(self, location):
-        return 'act1'
-
-    def feature_channels(self, location):
-        return self.conv.out_channels
+    def feature_info(self, location):
+        if location == 'expansion' or location == 'depthwise':
+            # no expansion or depthwise this block, use act after conv
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
+        return info

     def forward(self, x):
         x = self.conv(x)
@@ -165,12 +167,15 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

-    def feature_module(self, location):
-        # no expansion in this block, pre pw only feature extraction point
-        return 'conv_pw'
-
-    def feature_channels(self, location):
-        return self.conv_pw.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            # no expansion in this block, use depthwise, before SE
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
+        elif location == 'depthwise':  # after SE
+            info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
+        return info

     def forward(self, x):
         residual = x
@@ -232,16 +237,14 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn3 = norm_layer(out_chs, **norm_kwargs)

-    def feature_module(self, location):
-        if location == 'post_exp':
-            return 'act1'
-        return 'conv_pwl'
-
-    def feature_channels(self, location):
-        if location == 'post_exp':
-            return self.conv_pw.out_channels
-        # location == 'pre_pw'
-        return self.conv_pwl.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
+        elif location == 'depthwise':  # after SE
+            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+        return info

     def forward(self, x):
         residual = x
@@ -359,16 +362,15 @@ def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_ch
             mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)

-    def feature_module(self, location):
-        if location == 'post_exp':
-            return 'act1'
-        return 'conv_pwl'
-
-    def feature_channels(self, location):
-        if location == 'post_exp':
-            return self.conv_exp.out_channels
-        # location == 'pre_pw'
-        return self.conv_pwl.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
+        elif location == 'depthwise':
+            # there is no depthwise, take after SE, before PWL
+            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+        return info

     def forward(self, x):
         residual = x
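Each block type now answers a single feature_info(location) query instead of the old feature_module()/feature_channels() pair: it reports the sub-module to hook (empty for 'bottleneck', where the block's own output is used directly), the hook type, and the channel count. A small sketch of poking at that contract (the constructor arguments here are hypothetical, just enough to build a block):

from timm.models.efficientnet_blocks import InvertedResidual

# hypothetical block: 24 -> 40 channels, 5x5 depthwise, stride 2, 6x expansion
block = InvertedResidual(in_chs=24, out_chs=40, dw_kernel_size=5, stride=2, exp_ratio=6.0)

print(block.feature_info('expansion'))   # module='act1', hook_type='forward'
print(block.feature_info('depthwise'))   # module='conv_pwl', hook_type='forward_pre' (after SE, before pointwise-linear)
print(block.feature_info('bottleneck'))  # module='' -> no hook, num_chs=40 (the block output)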

timm/models/efficientnet_builder.py

Lines changed: 12 additions & 12 deletions
@@ -218,7 +218,7 @@ def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
         self.norm_kwargs = norm_kwargs
         self.drop_path_rate = drop_path_rate
         self.feature_location = feature_location
-        assert feature_location in ('pre_pwl', 'post_exp', '')
+        assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
         self.verbose = verbose

         # state updated during build, consumed by model
@@ -313,20 +313,21 @@ def __call__(self, in_chs, model_block_args):
                    block_args['stride'] = 1

                do_extract = False
-               if self.feature_location == 'pre_pwl':
+               if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
                    if last_block:
                        next_stage_idx = stage_idx + 1
                        if next_stage_idx >= len(model_block_args):
                            do_extract = True
                        else:
                            do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
-               elif self.feature_location == 'post_exp':
-                   if block_args['stride'] > 1 or (last_stack and last_block) :
+               elif self.feature_location == 'expansion':
+                   if block_args['stride'] > 1 or (last_stack and last_block):
                        do_extract = True
                if do_extract:
                    extract_features = self.feature_location

                next_dilation = current_dilation
+               next_output_stride = current_stride
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
@@ -347,14 +348,13 @@ def __call__(self, in_chs, model_block_args):

                # stash feature module name and channel info for model feature extraction
                if extract_features:
-                   feature_module = block.feature_module(extract_features)
-                   if feature_module:
-                       feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module
-                   feature_channels = block.feature_channels(extract_features)
-                   self.features[feature_idx] = dict(
-                       name=feature_module,
-                       num_chs=feature_channels
-                   )
+                   feature_info = block.feature_info(extract_features)
+                   if feature_info['module']:
+                       feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
+                   feature_info['stage_idx'] = stage_idx
+                   feature_info['block_idx'] = block_idx
+                   feature_info['reduction'] = current_stride
+                   self.features[feature_idx] = feature_info
                    feature_idx += 1

                total_block_idx += 1  # incr global block idx (across all stacks)
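For each extraction point the builder now stashes the block's whole feature_info dict and augments it with the block's position and the accumulated stride, rather than just a name/num_chs pair. Roughly, one entry of builder.features ends up shaped like the sketch below (the values are hypothetical; they depend on the arch def and feature_location):

# Hypothetical builder.features entry in the default 'bottleneck' mode
example_entry = {
    'module': '',       # empty -> EfficientNetFeatures takes the block output directly, no hook
    'hook_type': '',
    'num_chs': 40,
    'stage_idx': 2,     # which entry of model.blocks produced this feature
    'block_idx': 1,     # last block within that stage
    'reduction': 8,     # accumulated output stride when this feature is extracted
}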

timm/models/layers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 from .padding import get_padding
-from .avg_pool2d_same import AvgPool2dSame
+from .pool2d_same import AvgPool2dSame
 from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .mixed_conv2d import MixedConv2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
+from .pool2d_same import create_pool2d
 from .create_conv2d import create_conv2d
 from .create_attn import create_attn
 from .selective_kernel import SelectiveKernelConv
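create_pool2d is exported here from the new pool2d_same.py, a file not shown in this diff. Presumably it mirrors create_conv2d: resolve string padding via get_padding_value and only fall back to the dynamically padded 'SAME' variant when required. A hedged usage sketch, assuming a create_pool2d(pool_type, kernel_size, stride=None, **kwargs) style signature:

import torch
from timm.models.layers import create_pool2d

# assumed factory call; the real signature lives in pool2d_same.py, not in this diff
pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
out = pool(torch.randn(1, 32, 15, 15))
print(out.shape)  # TF-style 'SAME' pooling: 15 -> ceil(15 / 2) = 8 per spatial dim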

timm/models/layers/avg_pool2d_same.py

Lines changed: 0 additions & 31 deletions
This file was deleted.
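The deleted AvgPool2dSame now lives in the new pool2d_same.py (see the __init__.py import change above), alongside the MaxPool2dSame mentioned in the commit message. Since that file isn't shown in this diff, the sketch below is an assumption of what the max-pool variant looks like, built on the pad_same(value=...) extension to padding.py further down:

import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers.helpers import tup_pair
from timm.models.layers.padding import pad_same


class MaxPool2dSame(nn.Module):
    """ Sketch of a TF 'SAME'-padded max pool: pad dynamically, then pool with padding=0. """

    def __init__(self, kernel_size, stride=None, dilation=1, ceil_mode=False):
        super().__init__()
        self.kernel_size = tup_pair(kernel_size)
        self.stride = tup_pair(stride or kernel_size)
        self.dilation = tup_pair(dilation)
        self.ceil_mode = ceil_mode

    def forward(self, x):
        # pad with -inf so the padded border can never win the max
        x = pad_same(x, self.kernel_size, self.stride, self.dilation, value=-float('inf'))
        return F.max_pool2d(x, self.kernel_size, self.stride, 0, self.dilation, self.ceil_mode)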

timm/models/layers/cond_conv2d.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,8 @@
 from torch.nn import functional as F

 from .helpers import tup_pair
-from .conv2d_same import get_padding_value, conv2d_same
+from .conv2d_same import conv2d_same
+from timm.models.layers.padding import get_padding_value


 def get_condconv_initializer(initializer, num_experts, expert_shape):

timm/models/layers/conv2d_same.py

Lines changed: 3 additions & 26 deletions
@@ -5,10 +5,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Union, List, Tuple, Optional, Callable
-import math
+from typing import Tuple, Optional

-from .padding import get_padding, pad_same, is_static_pad
+from timm.models.layers.padding import get_padding_value
+from .padding import pad_same


 def conv2d_same(
@@ -31,29 +31,6 @@ def forward(self, x):
         return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


-def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
-    dynamic = False
-    if isinstance(padding, str):
-        # for any string padding, the padding will be calculated for you, one of three ways
-        padding = padding.lower()
-        if padding == 'same':
-            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
-            if is_static_pad(kernel_size, **kwargs):
-                # static case, no extra overhead
-                padding = get_padding(kernel_size, **kwargs)
-            else:
-                # dynamic 'SAME' padding, has runtime/GPU memory overhead
-                padding = 0
-                dynamic = True
-        elif padding == 'valid':
-            # 'VALID' padding, same as padding=0
-            padding = 0
-        else:
-            # Default to PyTorch style 'same'-ish symmetric padding
-            padding = get_padding(kernel_size, **kwargs)
-    return padding, dynamic
-
-
 def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
     padding = kwargs.pop('padding', '')
     kwargs.setdefault('bias', False)

timm/models/layers/padding.py

Lines changed: 26 additions & 3 deletions
@@ -3,7 +3,7 @@
 Hacked together by Ross Wightman
 """
 import math
-from typing import List
+from typing import List, Tuple

 import torch.nn.functional as F

@@ -25,9 +25,32 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):


 # Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
     ih, iw = x.size()[-2:]
     pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
     if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
     return x
+
+
+def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+    dynamic = False
+    if isinstance(padding, str):
+        # for any string padding, the padding will be calculated for you, one of three ways
+        padding = padding.lower()
+        if padding == 'same':
+            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+            if is_static_pad(kernel_size, **kwargs):
+                # static case, no extra overhead
+                padding = get_padding(kernel_size, **kwargs)
+            else:
+                # dynamic 'SAME' padding, has runtime/GPU memory overhead
+                padding = 0
+                dynamic = True
+        elif padding == 'valid':
+            # 'VALID' padding, same as padding=0
+            padding = 0
+        else:
+            # Default to PyTorch style 'same'-ish symmetric padding
+            padding = get_padding(kernel_size, **kwargs)
+    return padding, dynamic
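get_padding_value moves here out of conv2d_same.py so both the conv layers and the new pooling layers can resolve string padding the same way. A quick illustration of what it returns for a few requests (the values follow from get_padding and is_static_pad as defined in this file):

from timm.models.layers.padding import get_padding_value

# stride 1: 'SAME' is satisfiable with fixed symmetric padding -> static value, not dynamic
print(get_padding_value('same', 3, stride=1, dilation=1))  # (1, False)

# stride 2: the required padding depends on the input size -> pad at runtime via pad_same
print(get_padding_value('same', 3, stride=2, dilation=1))  # (0, True)

# 'valid' is simply zero padding
print(get_padding_value('valid', 3, stride=2))             # (0, False)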
