@@ -288,9 +288,6 @@ def forward(self: _DenseNet, signal: torch.Tensor) -> torch.Tensor:
288
288
289
289
290
290
class _Hook :
291
- def __init__ (self : _Hook ) -> None :
292
- self .outputs : list [nn .Module ] = []
293
-
294
291
def __call__ (
295
292
self : _Hook ,
296
293
_ : nn .Module ,
@@ -299,6 +296,9 @@ def __call__(
299
296
) -> None :
300
297
self .outputs .append (module_out )
301
298
299
+ def __init__ (self : _Hook ) -> None :
300
+ self .outputs : list [nn .Module ] = []
301
+
302
302
303
303
class _LeNet2D (nn .Module ):
304
304
def __init__ (self : _LeNet2D ) -> None :
@@ -497,6 +497,12 @@ def __init__(
497
497
498
498
499
499
class _UCIEpilepsy (Dataset [tuple [torch .Tensor , torch .Tensor ]]):
500
+ def __getitem__ (
501
+ self : _UCIEpilepsy ,
502
+ index : int ,
503
+ ) -> tuple [torch .Tensor , torch .Tensor ]:
504
+ return (self .data [index ], self .target [index ])
505
+
500
506
def __init__ (
501
507
self : _UCIEpilepsy ,
502
508
num_samples : int ,
@@ -543,12 +549,6 @@ def __init__(
543
549
)
544
550
self .data .unsqueeze_ (1 )
545
551
546
- def __getitem__ (
547
- self : _UCIEpilepsy ,
548
- index : int ,
549
- ) -> tuple [torch .Tensor , torch .Tensor ]:
550
- return (self .data [index ], self .target [index ])
551
-
552
552
def __len__ (self : _UCIEpilepsy ) -> int :
553
553
return self .target .shape [0 ]
554
554
@@ -640,120 +640,6 @@ def _densenet201(num_classes: int) -> nn.Module:
640
640
)
641
641
642
642
643
- def _make_layers (cfg : list ) -> nn .Module : # type: ignore[type-arg]
644
- layers : list [nn .Module ] = []
645
- in_channels = 1
646
- for cfg_element in cfg :
647
- if cfg_element == "M" :
648
- layers += [nn .MaxPool1d (kernel_size = 2 , stride = 2 )]
649
- else :
650
- conv1d = nn .Conv1d (in_channels , cfg_element , kernel_size = 3 , padding = 1 )
651
- layers += [conv1d , nn .ReLU ()]
652
- in_channels = cfg_element
653
- return nn .Sequential (* layers )
654
-
655
-
656
- def _replace_last_layer (
657
- num_classes : int ,
658
- model_base : nn .Module ,
659
- model_file_name : str ,
660
- ) -> nn .Module :
661
- if model_file_name .startswith (("alexnet" , "vgg" )):
662
- model_base .classifier [- 1 ] = nn .Linear (
663
- model_base .classifier [- 1 ].in_features ,
664
- num_classes ,
665
- )
666
- elif model_file_name .startswith ("resnet" ):
667
- model_base .fc = nn .Linear (model_base .fc .in_features , num_classes )
668
- elif model_file_name .startswith ("densenet" ):
669
- model_base .classifier = nn .Linear (
670
- model_base .classifier .in_features ,
671
- num_classes ,
672
- )
673
- return model_base
674
-
675
-
676
- def _resnet101 (num_classes : int ) -> nn .Module :
677
- return _ResNet (_Bottleneck , num_classes , expansion = 4 , layers = [3 , 4 , 23 , 3 ])
678
-
679
-
680
- def _resnet152 (num_classes : int ) -> nn .Module :
681
- return _ResNet (_Bottleneck , num_classes , expansion = 4 , layers = [3 , 8 , 36 , 3 ])
682
-
683
-
684
- def _resnet18 (num_classes : int ) -> nn .Module :
685
- return _ResNet (_BasicBlock , num_classes , expansion = 1 , layers = [2 , 2 , 2 , 2 ])
686
-
687
-
688
- def _resnet34 (num_classes : int ) -> nn .Module :
689
- return _ResNet (_BasicBlock , num_classes , expansion = 1 , layers = [3 , 4 , 6 , 3 ])
690
-
691
-
692
- def _resnet50 (num_classes : int ) -> nn .Module :
693
- return _ResNet (_Bottleneck , num_classes , expansion = 4 , layers = [3 , 4 , 6 , 3 ])
694
-
695
-
696
- def _vgg11 (num_classes : int ) -> nn .Module :
697
- cfg = [64 , "M" , 128 , "M" , 256 , 256 , "M" , 512 , 512 , "M" , 512 , 512 , "M" ]
698
- return _VGG (num_classes , _make_layers (cfg ))
699
-
700
-
701
- def _vgg13 (num_classes : int ) -> nn .Module :
702
- cfg = [64 , 64 , "M" , 128 , 128 , "M" , 256 , 256 , "M" , 512 , 512 , "M" , 512 , 512 , "M" ]
703
- return _VGG (num_classes , _make_layers (cfg ))
704
-
705
-
706
- def _vgg16 (num_classes : int ) -> nn .Module :
707
- cfg = [
708
- 64 ,
709
- 64 ,
710
- "M" ,
711
- 128 ,
712
- 128 ,
713
- "M" ,
714
- 256 ,
715
- 256 ,
716
- 256 ,
717
- "M" ,
718
- 512 ,
719
- 512 ,
720
- 512 ,
721
- "M" ,
722
- 512 ,
723
- 512 ,
724
- 512 ,
725
- "M" ,
726
- ]
727
- return _VGG (num_classes , _make_layers (cfg ))
728
-
729
-
730
- def _vgg19 (num_classes : int ) -> nn .Module :
731
- cfg = [
732
- 64 ,
733
- 64 ,
734
- "M" ,
735
- 128 ,
736
- 128 ,
737
- "M" ,
738
- 256 ,
739
- 256 ,
740
- 256 ,
741
- 256 ,
742
- "M" ,
743
- 512 ,
744
- 512 ,
745
- 512 ,
746
- 512 ,
747
- "M" ,
748
- 512 ,
749
- 512 ,
750
- 512 ,
751
- 512 ,
752
- "M" ,
753
- ]
754
- return _VGG (num_classes , _make_layers (cfg ))
755
-
756
-
757
643
def _main () -> None : # noqa: C901,PLR0912,PLR0915
758
644
if os .getenv ("STAGE" ):
759
645
num_samples = 11500
@@ -1004,6 +890,120 @@ def _main() -> None: # noqa: C901,PLR0912,PLR0915
1004
890
plt .close ()
1005
891
1006
892
893
+ def _make_layers (cfg : list ) -> nn .Module : # type: ignore[type-arg]
894
+ layers : list [nn .Module ] = []
895
+ in_channels = 1
896
+ for cfg_element in cfg :
897
+ if cfg_element == "M" :
898
+ layers += [nn .MaxPool1d (kernel_size = 2 , stride = 2 )]
899
+ else :
900
+ conv1d = nn .Conv1d (in_channels , cfg_element , kernel_size = 3 , padding = 1 )
901
+ layers += [conv1d , nn .ReLU ()]
902
+ in_channels = cfg_element
903
+ return nn .Sequential (* layers )
904
+
905
+
906
+ def _replace_last_layer (
907
+ num_classes : int ,
908
+ model_base : nn .Module ,
909
+ model_file_name : str ,
910
+ ) -> nn .Module :
911
+ if model_file_name .startswith (("alexnet" , "vgg" )):
912
+ model_base .classifier [- 1 ] = nn .Linear (
913
+ model_base .classifier [- 1 ].in_features ,
914
+ num_classes ,
915
+ )
916
+ elif model_file_name .startswith ("resnet" ):
917
+ model_base .fc = nn .Linear (model_base .fc .in_features , num_classes )
918
+ elif model_file_name .startswith ("densenet" ):
919
+ model_base .classifier = nn .Linear (
920
+ model_base .classifier .in_features ,
921
+ num_classes ,
922
+ )
923
+ return model_base
924
+
925
+
926
+ def _resnet101 (num_classes : int ) -> nn .Module :
927
+ return _ResNet (_Bottleneck , num_classes , expansion = 4 , layers = [3 , 4 , 23 , 3 ])
928
+
929
+
930
+ def _resnet152 (num_classes : int ) -> nn .Module :
931
+ return _ResNet (_Bottleneck , num_classes , expansion = 4 , layers = [3 , 8 , 36 , 3 ])
932
+
933
+
934
+ def _resnet18 (num_classes : int ) -> nn .Module :
935
+ return _ResNet (_BasicBlock , num_classes , expansion = 1 , layers = [2 , 2 , 2 , 2 ])
936
+
937
+
938
+ def _resnet34 (num_classes : int ) -> nn .Module :
939
+ return _ResNet (_BasicBlock , num_classes , expansion = 1 , layers = [3 , 4 , 6 , 3 ])
940
+
941
+
942
+ def _resnet50 (num_classes : int ) -> nn .Module :
943
+ return _ResNet (_Bottleneck , num_classes , expansion = 4 , layers = [3 , 4 , 6 , 3 ])
944
+
945
+
946
+ def _vgg11 (num_classes : int ) -> nn .Module :
947
+ cfg = [64 , "M" , 128 , "M" , 256 , 256 , "M" , 512 , 512 , "M" , 512 , 512 , "M" ]
948
+ return _VGG (num_classes , _make_layers (cfg ))
949
+
950
+
951
+ def _vgg13 (num_classes : int ) -> nn .Module :
952
+ cfg = [64 , 64 , "M" , 128 , 128 , "M" , 256 , 256 , "M" , 512 , 512 , "M" , 512 , 512 , "M" ]
953
+ return _VGG (num_classes , _make_layers (cfg ))
954
+
955
+
956
+ def _vgg16 (num_classes : int ) -> nn .Module :
957
+ cfg = [
958
+ 64 ,
959
+ 64 ,
960
+ "M" ,
961
+ 128 ,
962
+ 128 ,
963
+ "M" ,
964
+ 256 ,
965
+ 256 ,
966
+ 256 ,
967
+ "M" ,
968
+ 512 ,
969
+ 512 ,
970
+ 512 ,
971
+ "M" ,
972
+ 512 ,
973
+ 512 ,
974
+ 512 ,
975
+ "M" ,
976
+ ]
977
+ return _VGG (num_classes , _make_layers (cfg ))
978
+
979
+
980
+ def _vgg19 (num_classes : int ) -> nn .Module :
981
+ cfg = [
982
+ 64 ,
983
+ 64 ,
984
+ "M" ,
985
+ 128 ,
986
+ 128 ,
987
+ "M" ,
988
+ 256 ,
989
+ 256 ,
990
+ 256 ,
991
+ 256 ,
992
+ "M" ,
993
+ 512 ,
994
+ 512 ,
995
+ 512 ,
996
+ 512 ,
997
+ "M" ,
998
+ 512 ,
999
+ 512 ,
1000
+ 512 ,
1001
+ 512 ,
1002
+ "M" ,
1003
+ ]
1004
+ return _VGG (num_classes , _make_layers (cfg ))
1005
+
1006
+
1007
1007
M = TypeVar ("M" , _BasicBlock , _Bottleneck )
1008
1008
1009
1009
0 commit comments