@@ -82,3 +82,67 @@ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/0fa674e43
82
82
83
83
## Step 3 : Training the model on the dataset
84
84
85
+ We have used SGD as Optimization Algorithm here with learning rate (lr) = 0.003 and momentum = 0.9 as suggested in general sense. [ Typical lr values range from 0.0001 up to 1 and it is upon us to find a suitable value by cross validation
86
+
87
+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L33
88
+
89
+ To calculate the total training time, time module has been used. (Lines 34 and 48)
90
+
91
+ Trial and Error method can be used to find the suitable epoch value, for this code, it has been setup to be 18
92
+
93
+ Overall Training is being done as :
94
+
95
+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L33-L49
96
+
97
+ ## Step 4 : Testing the Model
98
+
99
+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L51-L66
100
+
101
+ ## Step 5 : Saving the model
102
+
103
+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L68
104
+
105
+ ## To View results for any random picture in the dataset, the following code can be used :
106
+
107
+ It also creates a graph displaying the probabilities returned by the model.
108
+
109
+ ``` py
110
+ import numpy as np
111
+ def view_classify (img , ps ):
112
+ ps = ps.cpu().data.numpy().squeeze()
113
+ fig, (ax1, ax2) = plt.subplots(figsize = (6 ,9 ), ncols = 2 )
114
+ ax1.imshow(img.resize_(1 , 28 , 28 ).numpy().squeeze())
115
+ ax1.axis(' off' )
116
+ ax2.barh(np.arange(10 ), ps)
117
+ ax2.set_aspect(0.1 )
118
+ ax2.set_yticks(np.arange(10 ))
119
+ ax2.set_yticklabels(np.arange(10 ))
120
+ ax2.set_title(' Class Probability' )
121
+ ax2.set_xlim(0 , 1.1 )
122
+ plt.tight_layout()
123
+ img,label= train[np.random.randint(0 ,10001 )]
124
+ image= img.view(1 , 784 )
125
+ with tch.no_grad():
126
+ logps = model(image)
127
+ ps = tch.exp(logps)
128
+ probab = list (ps.numpy()[0 ])
129
+ print (" Predicted Digit =" , probab.index(max (probab)))
130
+ view_classify(image.view(1 , 28 , 28 ), ps)
131
+ ```
132
+
133
+ ### Examples :
134
+
135
+ ![ image] ( https://user-images.githubusercontent.com/77016507/225422901-908e96de-629f-4d33-b7ba-819960a97d66.png )
136
+
137
+ ![ image] ( https://user-images.githubusercontent.com/77016507/225423008-3f858a52-2331-48e1-b271-f6d6e25e2d91.png )
138
+
139
+ ![ image] ( https://user-images.githubusercontent.com/77016507/225423232-d0249b38-e191-495d-b9fd-8c32eb20da57.png )
140
+
141
+ ### Model Accuracy : The Accuracy of the model with this code is approximately 97.8% to 98.02% with a training time of aprox. 3.5 to 4 minutes
142
+
143
+ ## Further Improvements :
144
+
145
+ 1 . Working on making graphical representation of useful data such as Loss vs Epoch Number etc.
146
+ 2 . Looking to test with different algorithms to strike a balance between training time and accuracy.
147
+
148
+ ### Contributions, Suggestions, and inputs on graphical representation for better understanding are welcome.
0 commit comments