Skip to content

Commit a09dcc2

Browse files
authored
Merge pull request #8 from rasbt/ch14-softmax
ch14-softmax update
2 parents d4b9be6 + 6bf3808 commit a09dcc2

File tree

4 files changed

+212
-94
lines changed

4 files changed

+212
-94
lines changed

ch14/ch14_part1.ipynb

Lines changed: 204 additions & 90 deletions
Large diffs are not rendered by default.

ch14/ch14_part1.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
298298
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
299299
train=True,
300300
transform=transform,
301-
download=False)
301+
download=True)
302302

303303
mnist_valid_dataset = Subset(mnist_dataset, torch.arange(10000))
304304
mnist_train_dataset = Subset(mnist_dataset, torch.arange(10000, len(mnist_dataset)))
@@ -310,6 +310,8 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
310310

311311

312312

313+
314+
313315
batch_size = 64
314316
torch.manual_seed(1)
315317
train_dl = DataLoader(mnist_train_dataset, batch_size, shuffle=True)
@@ -367,8 +369,6 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
367369
model.add_module('dropout', nn.Dropout(p=0.5))
368370

369371
model.add_module('fc2', nn.Linear(1024, 10))
370-
model.add_module('softmax', nn.Softmax(dim=1))
371-
372372

373373

374374

@@ -430,6 +430,8 @@ def train(model, num_epochs, train_dl, valid_dl):
430430

431431

432432

433+
434+
433435
x_arr = np.arange(len(hist[0])) + 1
434436

435437
fig = plt.figure(figsize=(12, 4))
@@ -446,6 +448,7 @@ def train(model, num_epochs, train_dl, valid_dl):
446448
ax.set_xlabel('Epoch', size=15)
447449
ax.set_ylabel('Accuracy', size=15)
448450

451+
#plt.savefig('figures/14_13.png')
449452
plt.show()
450453

451454

@@ -456,7 +459,6 @@ def train(model, num_epochs, train_dl, valid_dl):
456459
pred = model(mnist_test_dataset.data.unsqueeze(1) / 255.)
457460
is_correct = (torch.argmax(pred, dim=1) == mnist_test_dataset.targets).float()
458461
print(f'Test accuracy: {is_correct.mean():.4f}')
459-
460462

461463

462464

@@ -475,6 +477,8 @@ def train(model, num_epochs, train_dl, valid_dl):
475477
verticalalignment='center',
476478
transform=ax.transAxes)
477479

480+
481+
plt.savefig('figures/14_14.png')
478482
plt.show()
479483

480484

ch14/figures/14_13.png

-198 KB
Loading

ch14/figures/14_14.png

12 KB
Loading

0 commit comments

Comments
 (0)