
Commit 881aa5a

Merge pull request #730 from talesa/feature/more_parameterizable_xresnet1dplus
Allow for parameterizing block_szs in XResNet1dPlus
2 parents: db55074 + 20ccf5a

6 files changed (+12, -12 lines)

nbs/036_models.InceptionTimePlus.ipynb
Lines changed: 1 addition & 1 deletion

@@ -155,7 +155,7 @@
 "        self.seq_len = seq_len\n",
 "        if custom_head is not None: \n",
 "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, c_out, seq_len)\n",
+"            else: head = custom_head(self.head_nf, c_out, seq_len)\n",
 "        else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool, \n",
 "                                      fc_dropout=fc_dropout, bn=bn, y_range=y_range)\n",
 "        \n",

nbs/057_models.MINIROCKETPlus_Pytorch.ipynb
Lines changed: 2 additions & 2 deletions

@@ -250,7 +250,7 @@
 "        self.head_nf = num_features\n",
 "        if custom_head is not None: \n",
 "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, c_out, 1)\n",
+"            else: head = custom_head(self.head_nf, c_out, 1)\n",
 "        else:\n",
 "            layers = [Flatten()]\n",
 "            if bn:\n",
@@ -506,7 +506,7 @@
 "        self.head_nf = num_features\n",
 "        if custom_head is not None: \n",
 "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, c_out, 1)\n",
+"            else: head = custom_head(self.head_nf, c_out, 1)\n",
 "        else:\n",
 "            layers = [Flatten()]\n",
 "            if bn:\n",

nbs/059_models.XResNet1dPlus.ipynb
Lines changed: 3 additions & 3 deletions

@@ -45,7 +45,7 @@
 "class XResNet1dPlus(nn.Sequential):\n",
 "    @delegates(ResBlock1dPlus)\n",
 "    def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),\n",
-"                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, **kwargs):\n",
+"                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):\n",
 "\n",
 "        store_attr('block,expansion,act_cls,ks')\n",
 "        n_out = c_out or n_out # added for compatibility\n",
@@ -55,14 +55,14 @@
 "                          act=act_cls)\n",
 "                for i in range(3)]\n",
 "\n",
-"        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]\n",
+"        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*(len(layers)-4))]\n",
 "        block_szs = [64//expansion] + block_szs\n",
 "        blocks = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)\n",
 "        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)\n",
 "        self.head_nf = block_szs[-1]*expansion\n",
 "        if custom_head is not None: \n",
 "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, n_out, seq_len)\n",
+"            else: head = custom_head(self.head_nf, n_out, seq_len)\n",
 "        else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(block_szs[-1]*expansion, n_out))\n",
 "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
 "        self._init_cnn(self)\n",

tsai/models/InceptionTimePlus.py
Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ def __init__(self, c_in, c_out, seq_len=None, nf=32, nb_filters=None,
         self.seq_len = seq_len
         if custom_head is not None:
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, seq_len)
+            else: head = custom_head(self.head_nf, c_out, seq_len)
         else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool,
                                       fc_dropout=fc_dropout, bn=bn, y_range=y_range)

tsai/models/MINIROCKETPlus_Pytorch.py
Lines changed: 2 additions & 2 deletions

@@ -205,7 +205,7 @@ def __init__(self, c_in, c_out, seq_len, num_features=10_000, max_dilations_per_
         self.head_nf = num_features
         if custom_head is not None:
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, 1)
+            else: head = custom_head(self.head_nf, c_out, 1)
         else:
             layers = [Flatten()]
             if bn:
@@ -319,7 +319,7 @@ def __init__(self, c_in, c_out, seq_len, num_features=10_000, max_dilations_per_
         self.head_nf = num_features
         if custom_head is not None:
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, 1)
+            else: head = custom_head(self.head_nf, c_out, 1)
         else:
             layers = [Flatten()]
             if bn:
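Note that the MiniRocket variants call the head builder as `custom_head(head_nf, c_out, 1)`: the trailing `1` takes the place of `seq_len`, presumably because the extracted features carry no remaining time dimension (an assumption; the commit itself does not say). A compatible callable would therefore look like:

```python
# Hedged sketch of a callable head compatible with the call above
# (my_minirocket_head is a hypothetical name).
import torch.nn as nn

def my_minirocket_head(nf, c_out, seq_len):
    # seq_len arrives as 1 here, so the linear layer maps features straight to classes.
    return nn.Sequential(nn.Flatten(), nn.Linear(nf * seq_len, c_out))
```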

tsai/models/XResNet1dPlus.py
Lines changed: 3 additions & 3 deletions

@@ -14,7 +14,7 @@
 class XResNet1dPlus(nn.Sequential):
     @delegates(ResBlock1dPlus)
     def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),
-                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, **kwargs):
+                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):

         store_attr('block,expansion,act_cls,ks')
         n_out = c_out or n_out # added for compatibility
@@ -24,14 +24,14 @@ def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropo
                          act=act_cls)
                 for i in range(3)]

-        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
+        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*(len(layers)-4))]
         block_szs = [64//expansion] + block_szs
         blocks = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)
         backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)
         self.head_nf = block_szs[-1]*expansion
         if custom_head is not None:
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, n_out, seq_len)
+            else: head = custom_head(self.head_nf, n_out, seq_len)
         else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(block_szs[-1]*expansion, n_out))
         super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))
         self._init_cnn(self)
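Putting the change to use: a usage sketch, assuming a tsai install that includes this commit, that narrows the network via `block_szs_base` instead of the previously hard-coded `(64,128,256,512)`:

```python
# Usage sketch (assumes this commit is installed); tsai expects input
# tensors shaped (batch, channels, seq_len).
import torch
from tsai.models.XResNet1dPlus import XResNet1dPlus

model = XResNet1dPlus(expansion=4, layers=[3, 4, 6, 3], c_in=3, c_out=10,
                      seq_len=128, block_szs_base=(32, 64, 128, 256))
xb = torch.randn(8, 3, 128)
print(model(xb).shape)  # expected: torch.Size([8, 10])
```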
