@@ -35,11 +35,8 @@ def normalize(
     Returns:
         np.array, float, float: Normalized data, mean and std.
     """
-    if mean is None:
-        mean = np.mean(data)
-    if std is None:
-        std = np.std(data)
-    return (data - mean) / std, mean, std
+    # TODO: Return the normalized data as well as the mean and the standard deviation.
+    return None


 class BrainCNN(th.nn.Module):
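For reference, the body stubbed out here amounts to the minimal sketch below (self-contained; `np` is assumed to be the module's `numpy` import):

```python
import numpy as np


def normalize(data, mean=None, std=None):
    """Scale data to zero mean and unit standard deviation."""
    if mean is None:
        mean = np.mean(data)
    if std is None:
        std = np.std(data)
    # Returning mean and std lets callers reuse the training-set statistics
    # when normalizing the validation and test splits.
    return (data - mean) / std, mean, std
```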
@@ -53,24 +50,12 @@ def __init__(self):
         for architectural inspiration.
         """
         super().__init__()
-        self.conv1 = nn.Conv1d(in_channels=44, out_channels=64, kernel_size=3)
-        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
-        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3)
-        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
-        self.conv3 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3)
-        self.linear = nn.Linear(35456, 4)
-        self.relu = nn.ReLU()
+        # TODO: Implement me!!

     def forward(self, x):
         """Run the forward pass of the network."""
-        x = self.relu(self.conv1(x))
-        x = self.pool1(x)
-        x = self.relu(self.conv2(x))
-        x = self.pool2(x)
-        x = self.relu(self.conv3(x))
-        x = th.reshape(x, [x.shape[0], -1])
-        x = self.linear(x)
-        return x
+        # TODO: Return the result of the forward pass instead of 0.
+        return 0.


 def get_acc(
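The layers removed above correspond to the sketch below. The `th`/`nn` aliases for `torch`/`torch.nn` are assumed from the module header; the 44 input channels and the flattened size of 35456 are taken from the removed lines and depend on the dataset's window length:

```python
import torch as th
from torch import nn


class BrainCNN(th.nn.Module):
    """Small 1D CNN that classifies EEG windows into four classes."""

    def __init__(self):
        super().__init__()
        # Three Conv1d blocks over the EEG channel dimension, two with max pooling.
        self.conv1 = nn.Conv1d(in_channels=44, out_channels=64, kernel_size=3)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3)
        # 35456 matches the flattened conv output for the expected input length.
        self.linear = nn.Linear(35456, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the forward pass of the network."""
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.relu(self.conv3(x))
        x = th.reshape(x, [x.shape[0], -1])  # flatten to (batch, features)
        return self.linear(x)
```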
@@ -88,8 +73,8 @@ def get_acc(
     Returns:
         th.Tensor: The accuracy in [%].
     """
-    logits = net(eeg_input)
-    accuracy = th.mean((th.argmax(logits, -1) == labels).type(th.float))
+    # TODO: Compute the correct accuracy.
+    accuracy = 0.
     return accuracy

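The two removed lines compute the accuracy as the fraction of correct predictions (despite the "[%]" in the docstring, the value lies in [0, 1]). A sketch, assuming the signature uses the `net`, `eeg_input`, and `labels` names that appear in the removed body:

```python
import torch as th


def get_acc(net, eeg_input, labels):
    """Return the fraction of correctly classified samples as a th.Tensor."""
    logits = net(eeg_input)
    # Predicted class = argmax over the last (class) dimension; average the hits.
    accuracy = th.mean((th.argmax(logits, -1) == labels).type(th.float))
    return accuracy
```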
@@ -111,51 +96,7 @@ def get_acc(
         low_cut_hz=low_cut_hz,
     )

-    train_set_x, mean, std = normalize(train_set.X)
-    valid_set_x_np, _, _ = normalize(valid_set_np.X, mean, std)
-    test_set_x_np, _, _ = normalize(test_set_np.X, mean, std)
-
-    train_size = train_set.X.shape[0]
-    train_input = np.array_split(train_set_x, train_size // batch_size)
-    train_labels = np.array_split(train_set.y, train_size // batch_size)
-
-    valid_set_y = th.tensor(valid_set_np.y)
-    valid_set_x = th.tensor(valid_set_x_np)
-    test_set_y = th.tensor(test_set_np.y)
-    test_set_x = th.tensor(test_set_x_np)
-
-    cnn = BrainCNN()
-    opt = th.optim.Adam(cnn.parameters(), lr=0.001)
-    loss = nn.CrossEntropyLoss()
-
-    val_acc_list = []
-    for e in range(epochs):
-        train_loop = tqdm(
-            zip(train_input, train_labels),
-            total=len(train_input),
-            desc="Training Brain CNN",
-        )
-        for input_x, labels_y in train_loop:
-            input_x, _, _ = normalize(input_x, mean, std)
-            labels_y = th.tensor(labels_y)
-            input_x = th.tensor(input_x)
-
-            y_hat = cnn(input_x)
-            cel = loss(y_hat, labels_y)
-            cel.backward()
-            opt.step()
-            opt.zero_grad()
-            train_loop.set_description("Loss: {:2.3f}".format(cel))
-
-        val_accuracy = get_acc(cnn, valid_set_x, valid_set_y)
-        print("Validation accuracy {:2.3f} at epoch {}".format(val_accuracy, e + 1))  # type: ignore
-        val_acc_list.append(val_accuracy)
-
-    test_accuracy = get_acc(cnn, test_set_x, test_set_y)
-    print("Test accuracy: {:2.3f}".format(test_accuracy))  # type: ignore
-    plt.plot(val_acc_list, label="Validation accuracy")
-    plt.plot(len(val_acc_list) - 1, test_accuracy, ".", label="Test accuracy")
-    plt.xlabel("epochs")
-    plt.ylabel("accuracy")
-    plt.legend()
-    plt.show()
+    # Set up Network training with validation and a final test-accuracy measurement.
+    # Use PyTorch's Adam optimizer.
+    # Use the X and y attributes of the set objects to access the EEG measurements
+    # and corresponding labels.
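A condensed sketch of the training, validation, and test loop that these comments ask for, based on the removed code above. Names such as `train_set`, `valid_set_np`, `test_set_np`, `batch_size`, and `epochs` come from the surrounding script; the tqdm progress bar and matplotlib plotting are omitted for brevity. Unlike the removed version, the sketch normalizes the training data only once (the removed loop re-normalized each already-normalized batch with the same statistics):

```python
# Normalize with training-set statistics and convert the held-out splits to tensors.
train_set_x, mean, std = normalize(train_set.X)
valid_set_x = th.tensor(normalize(valid_set_np.X, mean, std)[0])
valid_set_y = th.tensor(valid_set_np.y)
test_set_x = th.tensor(normalize(test_set_np.X, mean, std)[0])
test_set_y = th.tensor(test_set_np.y)

# Split the already-normalized training data into mini-batches.
train_size = train_set.X.shape[0]
train_input = np.array_split(train_set_x, train_size // batch_size)
train_labels = np.array_split(train_set.y, train_size // batch_size)

cnn = BrainCNN()
opt = th.optim.Adam(cnn.parameters(), lr=0.001)
loss = nn.CrossEntropyLoss()

val_acc_list = []
for e in range(epochs):
    for input_x, labels_y in zip(train_input, train_labels):
        cel = loss(cnn(th.tensor(input_x)), th.tensor(labels_y))
        cel.backward()
        opt.step()
        opt.zero_grad()
    val_accuracy = get_acc(cnn, valid_set_x, valid_set_y)
    val_acc_list.append(val_accuracy)
    print("Validation accuracy {:2.3f} at epoch {}".format(val_accuracy, e + 1))

test_accuracy = get_acc(cnn, test_set_x, test_set_y)
print("Test accuracy: {:2.3f}".format(test_accuracy))
```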