Skip to content

Commit 992dca4

Browse files
Update README.md
1 parent a014ffa commit 992dca4

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

README.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,67 @@ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/0fa674e43
8282

8383
## Step 3 : Training the model on the dataset
8484

85+
We have used SGD as Optimization Algorithm here with learning rate (lr) = 0.003 and momentum = 0.9 as suggested in general sense. [Typical lr values range from 0.0001 up to 1 and it is upon us to find a suitable value by cross validation
86+
87+
https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L33
88+
89+
To calculate the total training time, time module has been used. (Lines 34 and 48)
90+
91+
Trial and Error method can be used to find the suitable epoch value, for this code, it has been setup to be 18
92+
93+
Overall Training is being done as :
94+
95+
https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L33-L49
96+
97+
## Step 4 : Testing the Model
98+
99+
https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L51-L66
100+
101+
## Step 5 : Saving the model
102+
103+
https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L68
104+
105+
## To View results for any random picture in the dataset, the following code can be used :
106+
107+
It also creates a graph displaying the probabilities returned by the model.
108+
109+
```py
110+
import numpy as np
111+
def view_classify(img, ps):
112+
ps = ps.cpu().data.numpy().squeeze()
113+
fig, (ax1, ax2) = plt.subplots(figsize=(6,9), ncols=2)
114+
ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze())
115+
ax1.axis('off')
116+
ax2.barh(np.arange(10), ps)
117+
ax2.set_aspect(0.1)
118+
ax2.set_yticks(np.arange(10))
119+
ax2.set_yticklabels(np.arange(10))
120+
ax2.set_title('Class Probability')
121+
ax2.set_xlim(0, 1.1)
122+
plt.tight_layout()
123+
img,label=train[np.random.randint(0,10001)]
124+
image=img.view(1, 784)
125+
with tch.no_grad():
126+
logps = model(image)
127+
ps = tch.exp(logps)
128+
probab = list(ps.numpy()[0])
129+
print("Predicted Digit =", probab.index(max(probab)))
130+
view_classify(image.view(1, 28, 28), ps)
131+
```
132+
133+
### Examples :
134+
135+
![image](https://user-images.githubusercontent.com/77016507/225422901-908e96de-629f-4d33-b7ba-819960a97d66.png)
136+
137+
![image](https://user-images.githubusercontent.com/77016507/225423008-3f858a52-2331-48e1-b271-f6d6e25e2d91.png)
138+
139+
![image](https://user-images.githubusercontent.com/77016507/225423232-d0249b38-e191-495d-b9fd-8c32eb20da57.png)
140+
141+
### Model Accuracy : The Accuracy of the model with this code is approximately 97.8% to 98.02% with a training time of aprox. 3.5 to 4 minutes
142+
143+
## Further Improvements :
144+
145+
1. Working on making graphical representation of useful data such as Loss vs Epoch Number etc.
146+
2. Looking to test with different algorithms to strike a balance between training time and accuracy.
147+
148+
### Contributions, Suggestions, and inputs on graphical representation for better understanding are welcome.

0 commit comments

Comments
 (0)