@@ -298,7 +298,7 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
298
298
mnist_dataset = torchvision .datasets .MNIST (root = image_path ,
299
299
train = True ,
300
300
transform = transform ,
301
- download = False )
301
+ download = True )
302
302
303
303
mnist_valid_dataset = Subset (mnist_dataset , torch .arange (10000 ))
304
304
mnist_train_dataset = Subset (mnist_dataset , torch .arange (10000 , len (mnist_dataset )))
@@ -310,6 +310,8 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
310
310
311
311
312
312
313
+
314
+
313
315
batch_size = 64
314
316
torch .manual_seed (1 )
315
317
train_dl = DataLoader (mnist_train_dataset , batch_size , shuffle = True )
@@ -367,8 +369,6 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
367
369
model .add_module ('dropout' , nn .Dropout (p = 0.5 ))
368
370
369
371
model .add_module ('fc2' , nn .Linear (1024 , 10 ))
370
- model .add_module ('softmax' , nn .Softmax (dim = 1 ))
371
-
372
372
373
373
374
374
@@ -430,6 +430,8 @@ def train(model, num_epochs, train_dl, valid_dl):
430
430
431
431
432
432
433
+
434
+
433
435
x_arr = np .arange (len (hist [0 ])) + 1
434
436
435
437
fig = plt .figure (figsize = (12 , 4 ))
@@ -446,6 +448,7 @@ def train(model, num_epochs, train_dl, valid_dl):
446
448
ax .set_xlabel ('Epoch' , size = 15 )
447
449
ax .set_ylabel ('Accuracy' , size = 15 )
448
450
451
+ #plt.savefig('figures/14_13.png')
449
452
plt .show ()
450
453
451
454
@@ -456,7 +459,6 @@ def train(model, num_epochs, train_dl, valid_dl):
456
459
pred = model (mnist_test_dataset .data .unsqueeze (1 ) / 255. )
457
460
is_correct = (torch .argmax (pred , dim = 1 ) == mnist_test_dataset .targets ).float ()
458
461
print (f'Test accuracy: { is_correct .mean ():.4f} ' )
459
-
460
462
461
463
462
464
@@ -475,6 +477,8 @@ def train(model, num_epochs, train_dl, valid_dl):
475
477
verticalalignment = 'center' ,
476
478
transform = ax .transAxes )
477
479
480
+
481
+ plt .savefig ('figures/14_14.png' )
478
482
plt .show ()
479
483
480
484
0 commit comments