Skip to content

Commit a7d5a29

Browse files
committed
pytorch brain decode exercise.
1 parent e8b47ab commit a7d5a29

File tree

2 files changed

+11
-76
lines changed

2 files changed

+11
-76
lines changed

src/function.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/train_brain_decoder.py

Lines changed: 11 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,8 @@ def normalize(
3535
Returns:
3636
np.array, float, float: Normalized data, mean and std.
3737
"""
38-
if mean is None:
39-
mean = np.mean(data)
40-
if std is None:
41-
std = np.std(data)
42-
return (data - mean) / std, mean, std
38+
# TODO: Return the normalized data as well as the mean and the standard deviation.
39+
return None
4340

4441

4542
class BrainCNN(th.nn.Module):
@@ -53,24 +50,12 @@ def __init__(self):
5350
for architectural inspiration.
5451
"""
5552
super().__init__()
56-
self.conv1 = nn.Conv1d(in_channels=44, out_channels=64, kernel_size=3)
57-
self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
58-
self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3)
59-
self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
60-
self.conv3 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3)
61-
self.linear = nn.Linear(35456, 4)
62-
self.relu = nn.ReLU()
53+
# TODO: Implement me!!
6354

6455
def forward(self, x):
6556
"""Run the forward pass of the network."""
66-
x = self.relu(self.conv1(x))
67-
x = self.pool1(x)
68-
x = self.relu(self.conv2(x))
69-
x = self.pool2(x)
70-
x = self.relu(self.conv3(x))
71-
x = th.reshape(x, [x.shape[0], -1])
72-
x = self.linear(x)
73-
return x
57+
# TODO: Return the result of the forward pass instead of 0.
58+
return 0.
7459

7560

7661
def get_acc(
@@ -88,8 +73,8 @@ def get_acc(
8873
Returns:
8974
th.Tensor: The accuracy in [%].
9075
"""
91-
logits = net(eeg_input)
92-
accuracy = th.mean((th.argmax(logits, -1) == labels).type(th.float))
76+
# TODO: Compute the correct accuracy.
77+
accuracy = 0.
9378
return accuracy
9479

9580

@@ -111,51 +96,7 @@ def get_acc(
11196
low_cut_hz=low_cut_hz,
11297
)
11398

114-
train_set_x, mean, std = normalize(train_set.X)
115-
valid_set_x_np, _, _ = normalize(valid_set_np.X, mean, std)
116-
test_set_x_np, _, _ = normalize(test_set_np.X, mean, std)
117-
118-
train_size = train_set.X.shape[0]
119-
train_input = np.array_split(train_set_x, train_size // batch_size)
120-
train_labels = np.array_split(train_set.y, train_size // batch_size)
121-
122-
valid_set_y = th.tensor(valid_set_np.y)
123-
valid_set_x = th.tensor(valid_set_x_np)
124-
test_set_y = th.tensor(test_set_np.y)
125-
test_set_x = th.tensor(test_set_x_np)
126-
127-
cnn = BrainCNN()
128-
opt = th.optim.Adam(cnn.parameters(), lr=0.001)
129-
loss = nn.CrossEntropyLoss()
130-
131-
val_acc_list = []
132-
for e in range(epochs):
133-
train_loop = tqdm(
134-
zip(train_input, train_labels),
135-
total=len(train_input),
136-
desc="Training Brain CNN",
137-
)
138-
for input_x, labels_y in train_loop:
139-
input_x, _, _ = normalize(input_x, mean, std)
140-
labels_y = th.tensor(labels_y)
141-
input_x = th.tensor(input_x)
142-
143-
y_hat = cnn(input_x)
144-
cel = loss(y_hat, labels_y)
145-
cel.backward()
146-
opt.step()
147-
opt.zero_grad()
148-
train_loop.set_description("Loss: {:2.3f}".format(cel))
149-
150-
val_accuracy = get_acc(cnn, valid_set_x, valid_set_y)
151-
print("Validation accuracy {:2.3f} at epoch {}".format(val_accuracy, e + 1)) # type: ignore
152-
val_acc_list.append(val_accuracy)
153-
154-
test_accuracy = get_acc(cnn, test_set_x, test_set_y)
155-
print("Test accuracy: {:2.3f}".format(test_accuracy)) # type: ignore
156-
plt.plot(val_acc_list, label="Validation accuracy")
157-
plt.plot(len(val_acc_list) - 1, test_accuracy, ".", label="Test accuracy")
158-
plt.xlabel("epochs")
159-
plt.ylabel("accuracy")
160-
plt.legend()
161-
plt.show()
99+
# Set up Network training with validation and a final test-accuracy measurement.
100+
# Use PyTorch's Adam optimizer.
101+
# Use the X and y attributes of the set objects to access the EEG measurements
102+
# and corresponding labels.

0 commit comments

Comments
 (0)