@@ -60,17 +60,6 @@ def forward(self: _Alexnet, signal: torch.Tensor) -> torch.Tensor:
         return output


-def _conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Module:
-    return nn.Conv1d(
-        in_planes,
-        out_planes,
-        kernel_size=3,
-        stride=stride,
-        padding=1,
-        bias=False,
-    )
-
-
 class _BasicBlock(nn.Module):
     def __init__(
         self: _BasicBlock,
@@ -102,10 +91,6 @@ def forward(self: _BasicBlock, signal: torch.Tensor) -> torch.Tensor:
         return output


-def _conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Module:
-    return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
 class _Bottleneck(nn.Module):
     def __init__(
         self: _Bottleneck,
@@ -143,26 +128,6 @@ def forward(self: _Bottleneck, signal: torch.Tensor) -> torch.Tensor:
         return output


-def _replace_last_layer(
-    num_classes: int,
-    model_base: nn.Module,
-    model_file_name: str,
-) -> nn.Module:
-    if model_file_name.startswith(("alexnet", "vgg")):
-        model_base.classifier[-1] = nn.Linear(
-            model_base.classifier[-1].in_features,
-            num_classes,
-        )
-    elif model_file_name.startswith("resnet"):
-        model_base.fc = nn.Linear(model_base.fc.in_features, num_classes)
-    elif model_file_name.startswith("densenet"):
-        model_base.classifier = nn.Linear(
-            model_base.classifier.in_features,
-            num_classes,
-        )
-    return model_base
-
-
 class _CNNOneLayer(nn.Module):
     def __init__(
         self: _CNNOneLayer,
@@ -206,6 +171,24 @@ def forward(self: _CNNTwoLayers, signal: torch.Tensor) -> torch.Tensor:
         return output


+class _DenseBlock(nn.Sequential):
+    def __init__(
+        self: _DenseBlock,
+        bn_size: int,
+        growth_rate: int,
+        num_input_features: int,
+        num_layers: int,
+    ) -> None:
+        super().__init__()
+        for index in range(num_layers):
+            layer = _DenseLayer(
+                bn_size,
+                growth_rate,
+                num_input_features + index * growth_rate,
+            )
+            self.add_module("denselayer%d" % (index + 1), layer)
+
+
 class _DenseLayer(nn.Sequential):
     def __init__(
         self: _DenseLayer,
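
Note: the hunk above moves _DenseBlock ahead of _DenseLayer. The detail worth keeping in mind is the channel bookkeeping: layer `index` receives num_input_features + index * growth_rate channels, because every _DenseLayer concatenates its growth_rate new feature maps onto its input. A minimal sketch of that arithmetic, with illustrative numbers not taken from the commit:

    # Each dense layer's input widens by growth_rate channels.
    num_input_features, growth_rate, num_layers = 64, 32, 6
    for index in range(num_layers):
        in_channels = num_input_features + index * growth_rate
        print(f"denselayer{index + 1}: {in_channels} -> {in_channels + growth_rate} channels")
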
@@ -245,46 +228,6 @@ def forward(self: _DenseLayer, signal: torch.Tensor) -> torch.Tensor:
         return torch.cat([signal, new_features], 1)


-class _DenseBlock(nn.Sequential):
-    def __init__(
-        self: _DenseBlock,
-        bn_size: int,
-        growth_rate: int,
-        num_input_features: int,
-        num_layers: int,
-    ) -> None:
-        super().__init__()
-        for index in range(num_layers):
-            layer = _DenseLayer(
-                bn_size,
-                growth_rate,
-                num_input_features + index * growth_rate,
-            )
-            self.add_module("denselayer%d" % (index + 1), layer)
-
-
-class _Transition(nn.Sequential):
-    def __init__(
-        self: _Transition,
-        num_input_features: int,
-        num_output_features: int,
-    ) -> None:
-        super().__init__()
-        self.add_module("norm", nn.BatchNorm1d(num_input_features))
-        self.add_module("relu", nn.ReLU())
-        self.add_module(
-            "conv",
-            nn.Conv1d(
-                num_input_features,
-                num_output_features,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-        )
-        self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
-
-
 class _DenseNet(nn.Module):
     def __init__(
         self: _DenseNet,
@@ -399,30 +342,7 @@ def forward(self: _Lenet, signal: torch.Tensor) -> torch.Tensor:
         return output


-M = TypeVar("M", _BasicBlock, _Bottleneck)
-
-
 class _ResNet(nn.Module):
-    def _make_layer(  # noqa: PLR0913
-        self: _ResNet,
-        block: type[M],
-        blocks: int,
-        expansion: int,
-        planes: int,
-        stride: int = 1,
-    ) -> nn.Module:
-        downsample = None
-        if stride != 1 or self.inplanes != planes * expansion:  # type: ignore[has-type]
-            downsample = nn.Sequential(
-                _conv1x1(self.inplanes, planes * expansion, stride),  # type: ignore[has-type]
-                nn.BatchNorm1d(planes * expansion),
-            )
-        layers = [block(downsample, self.inplanes, planes, stride)]  # type: ignore[has-type]
-        self.inplanes = planes * expansion
-        for _ in range(1, blocks):
-            layers.append(block(None, self.inplanes, planes))  # noqa: PERF401
-        return nn.Sequential(*layers)
-
     def __init__(
         self: _ResNet,
         block: type[M],
@@ -453,6 +373,26 @@ def __init__(
                 nn.init.constant_(module.weight, 1)
                 nn.init.constant_(module.bias, 0)

+    def _make_layer(  # noqa: PLR0913
+        self: _ResNet,
+        block: type[M],
+        blocks: int,
+        expansion: int,
+        planes: int,
+        stride: int = 1,
+    ) -> nn.Module:
+        downsample = None
+        if stride != 1 or self.inplanes != planes * expansion:
+            downsample = nn.Sequential(
+                _conv1x1(self.inplanes, planes * expansion, stride),
+                nn.BatchNorm1d(planes * expansion),
+            )
+        layers = [block(downsample, self.inplanes, planes, stride)]
+        self.inplanes = planes * expansion
+        for _ in range(1, blocks):
+            layers.append(block(None, self.inplanes, planes))  # noqa: PERF401
+        return nn.Sequential(*layers)
+
     def forward(self: _ResNet, signal: torch.Tensor) -> torch.Tensor:
         out = self.conv1(signal)
         out = self.bn1(out)
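
Note: the relocated _make_layer is unchanged except that the `# type: ignore[has-type]` comments are gone; once the method sits after __init__, mypy has already seen the self.inplanes assignment, so the suppressions are no longer needed. The downsample rule itself: a 1x1 convolution plus BatchNorm realigns the residual shortcut whenever stride or channel width would otherwise mismatch the block output. A standalone sketch of that rule, with illustrative sizes (178 echoes the UCI epilepsy segment length but is not taken from this hunk):

    import torch
    from torch import nn

    # Shortcut projection as in _make_layer: 1x1 conv + BatchNorm whenever
    # stride != 1 or the channel count changes (sizes illustrative).
    inplanes, planes, expansion, stride = 64, 128, 1, 2
    downsample = nn.Sequential(
        nn.Conv1d(inplanes, planes * expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm1d(planes * expansion),
    )
    signal = torch.randn(8, inplanes, 178)
    print(downsample(signal).shape)  # torch.Size([8, 128, 89])
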
@@ -534,6 +474,28 @@ def forward(self: _Spectrogram, signal: torch.Tensor) -> torch.Tensor:
         return self.model_base(out)  # type: ignore[no-any-return]


+class _Transition(nn.Sequential):
+    def __init__(
+        self: _Transition,
+        num_input_features: int,
+        num_output_features: int,
+    ) -> None:
+        super().__init__()
+        self.add_module("norm", nn.BatchNorm1d(num_input_features))
+        self.add_module("relu", nn.ReLU())
+        self.add_module(
+            "conv",
+            nn.Conv1d(
+                num_input_features,
+                num_output_features,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+        )
+        self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
+
+
 class _UCIEpilepsy(Dataset[tuple[torch.Tensor, torch.Tensor]]):
     def __init__(
         self: _UCIEpilepsy,
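
Note: the relocated _Transition compresses a dense block's output along both axes: the 1x1 convolution remaps channels (DenseNet conventionally halves them) and the average pooling halves the temporal length. A shape check under assumed sizes, mirroring the block's norm-relu-conv-pool order:

    import torch
    from torch import nn

    # Same norm -> relu -> 1x1 conv -> avg-pool pipeline as _Transition,
    # built inline with illustrative channel counts.
    transition = nn.Sequential(
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Conv1d(256, 128, kernel_size=1, stride=1, bias=False),
        nn.AvgPool1d(kernel_size=2, stride=2),
    )
    print(transition(torch.randn(4, 256, 88)).shape)  # torch.Size([4, 128, 44])
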
@@ -592,20 +554,6 @@ def __len__(self: _UCIEpilepsy) -> int:


 class _VGG(nn.Module):
-    def _initialize_weights(self: _VGG) -> None:
-        for module in self.modules():
-            if isinstance(module, nn.Conv1d):
-                nn.init.kaiming_normal_(
-                    module.weight,
-                    mode="fan_out",
-                    nonlinearity="relu",
-                )
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-            elif isinstance(module, nn.Linear):
-                nn.init.normal_(module.weight, 0, 0.01)
-                nn.init.constant_(module.bias, 0)
-
     def __init__(self: _VGG, num_classes: int, features: nn.Module) -> None:
         super().__init__()
         self.features = features
@@ -620,13 +568,42 @@ def __init__(self: _VGG, num_classes: int, features: nn.Module) -> None:
         )
         self._initialize_weights()

+    def _initialize_weights(self: _VGG) -> None:
+        for module in self.modules():
+            if isinstance(module, nn.Conv1d):
+                nn.init.kaiming_normal_(
+                    module.weight,
+                    mode="fan_out",
+                    nonlinearity="relu",
+                )
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+            elif isinstance(module, nn.Linear):
+                nn.init.normal_(module.weight, 0, 0.01)
+                nn.init.constant_(module.bias, 0)
+
     def forward(self: _VGG, signal: torch.Tensor) -> torch.Tensor:
         out = self.features(signal)
         out = out.view(out.size(0), -1)
         output: torch.Tensor = self.classifier(out)
         return output


+def _conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Module:
+    return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+def _conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Module:
+    return nn.Conv1d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False,
+    )
+
+
 def _densenet121(num_classes: int) -> nn.Module:
     return _DenseNet(
         block_config=(6, 12, 24, 16),
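
Note: the two convolution helpers land here in their new, alphabetical position. _conv3x3 pads by 1 so that at stride 1 the output length equals the input length (L_out = L + 2*1 - 3 + 1 = L), while _conv1x1 only remaps channels. A quick check with illustrative sizes:

    import torch
    from torch import nn

    # Both helpers preserve sequence length at stride 1; only channels change.
    conv3 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1, bias=False)
    conv1 = nn.Conv1d(16, 32, kernel_size=1, stride=1, bias=False)
    x = torch.randn(2, 16, 100)
    print(conv3(x).shape, conv1(x).shape)  # both torch.Size([2, 32, 100])
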
@@ -676,6 +653,26 @@ def _make_layers(cfg: list) -> nn.Module:  # type: ignore[type-arg]
     return nn.Sequential(*layers)


+def _replace_last_layer(
+    num_classes: int,
+    model_base: nn.Module,
+    model_file_name: str,
+) -> nn.Module:
+    if model_file_name.startswith(("alexnet", "vgg")):
+        model_base.classifier[-1] = nn.Linear(
+            model_base.classifier[-1].in_features,
+            num_classes,
+        )
+    elif model_file_name.startswith("resnet"):
+        model_base.fc = nn.Linear(model_base.fc.in_features, num_classes)
+    elif model_file_name.startswith("densenet"):
+        model_base.classifier = nn.Linear(
+            model_base.classifier.in_features,
+            num_classes,
+        )
+    return model_base
+
+
 def _resnet101(num_classes: int) -> nn.Module:
     return _ResNet(_Bottleneck, num_classes, expansion=4, layers=[3, 4, 23, 3])

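
Note: _replace_last_layer dispatches on the model file-name prefix because the supported architectures expose their classification head under different attributes: classifier[-1] for AlexNet/VGG, fc for ResNet, classifier for DenseNet. A hypothetical usage sketch with a stand-in model (the class below is illustrative, not from the repository):

    from torch import nn

    class _TinyResnetLike(nn.Module):
        # Hypothetical stand-in exposing a ResNet-style head.
        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(512, 2)

    model_base = _TinyResnetLike()
    # The "resnet" branch of _replace_last_layer boils down to:
    model_base.fc = nn.Linear(model_base.fc.in_features, 5)
    print(model_base.fc)  # Linear(in_features=512, out_features=5, bias=True)
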
@@ -1007,5 +1004,8 @@ def _main() -> None:  # noqa: C901,PLR0912,PLR0915
     plt.close()


+M = TypeVar("M", _BasicBlock, _Bottleneck)
+
+
 if __name__ == "__main__":
     _main()
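
Note: moving M = TypeVar("M", _BasicBlock, _Bottleneck) below _main only works if the type[M] annotations in _ResNet are not evaluated at definition time; the self: _ResNet style annotations elsewhere in the file suggest the module relies on from __future__ import annotations, though that import is outside this diff. A self-contained sketch of the pattern, under that assumption:

    from __future__ import annotations

    from typing import TypeVar

    class _A: ...
    class _B: ...

    class _Uses:
        # With postponed evaluation (PEP 563), M may be defined further down.
        def pick(self: _Uses, block: type[M]) -> type[M]:
            return block

    M = TypeVar("M", _A, _B)

    print(_Uses().pick(_A))  # <class '__main__._A'>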