This code trains a fully connected neural network using PyTorch on the MNIST dataset to recognize handwritten digits (0–9).
Imports PyTorch, its neural network module, data loading utilities, and the dataset and image transformation tools from torchvision.
Downloads the MNIST dataset (handwritten digits).
Converts each 28×28 grayscale image into a PyTorch tensor, scaling pixel values from 0–255 to the range [0, 1].
Prepares batches of data (64 images per batch) for efficient training and testing.
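Batching is handled by `DataLoader`. The sketch below uses a random `TensorDataset` as a hypothetical offline stand-in for MNIST so it runs without a download; with the real data you would pass the MNIST dataset objects instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 64

# Offline stand-in for the MNIST training set: 200 random "images" with labels.
torch.manual_seed(0)
fake_data = TensorDataset(torch.rand(200, 1, 28, 28), torch.randint(0, 10, (200,)))

train_dataloader = DataLoader(fake_data, batch_size=batch_size)

# 200 samples in batches of 64 -> three full batches and a final batch of 8.
batch_sizes = [X.shape[0] for X, y in train_dataloader]
print(batch_sizes)  # [64, 64, 64, 8]
```

`shuffle` is left at its default (`False`) here. Passing `shuffle=True` for the training loader is a common choice, but whether the original code does so isn't stated.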
Checks the dataset sizes (60,000 training and 10,000 test images) and inspects the shape of one batch.
X shape: [64, 1, 28, 28] → [N, C, H, W]: 64 images per batch, 1 (grayscale) channel, 28×28 pixels.
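The shape check can be done by pulling a single batch from a loader, as in this sketch (again on a hypothetical random stand-in with MNIST-like shapes). The labels come out as a 1-D tensor with one integer class index per image.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Offline stand-in with MNIST-like samples: 1x28x28 images, integer labels 0..9.
fake_test = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))
test_dataloader = DataLoader(fake_test, batch_size=64)

print(len(fake_test))  # number of samples (10,000 for the real MNIST test set)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")  # torch.Size([64, 1, 28, 28])
    print(f"Shape of y: {y.shape} {y.dtype}")     # torch.Size([64]) torch.int64
    break  # one batch is enough for a shape check
```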
Selects a GPU if one is available (CUDA on NVIDIA cards, MPS on Apple Silicon), otherwise falls back to the CPU.
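One way to express the device choice, assuming a PyTorch version recent enough to include the MPS backend (1.12 or later). The function name `pick_device` is illustrative.

```python
import torch


def pick_device() -> torch.device:
    """Prefer an NVIDIA GPU (CUDA), then Apple Silicon (MPS), else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(f"Using {device} device")
```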
A fully connected (dense) neural network with:
Input layer: 784 features (flattened 28×28 image)
2 hidden layers: 512 neurons each, ReLU activation
Output layer: 10 neurons (for 10 digit classes)
Creates the model and moves it to the selected device.
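The architecture above maps naturally onto `nn.Flatten` followed by an `nn.Sequential` stack. The class name `NeuralNetwork` and its attribute names are assumptions, not taken from the original code, and the device selection is simplified to CUDA or CPU. The network returns raw logits with no softmax layer, which is what `CrossEntropyLoss` expects.

```python
import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # [N, 1, 28, 28] -> [N, 784]
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),  # 784 input features -> hidden layer 1
            nn.ReLU(),
            nn.Linear(512, 512),      # hidden layer 1 -> hidden layer 2
            nn.ReLU(),
            nn.Linear(512, 10),       # hidden layer 2 -> 10 class scores (logits)
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)  # raw logits, no softmax


device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)  # create the model and move its parameters to the device

n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 669706

logits = model(torch.rand(64, 1, 28, 28, device=device))
print(logits.shape)  # torch.Size([64, 10])
```

The parameter count follows from the layer sizes: (784×512 + 512) + (512×512 + 512) + (512×10 + 10) = 669,706.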
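CrossEntropyLoss as the classification loss; it takes the raw output scores (logits) and applies log-softmax internally.
Adam optimizer, which adapts the step size of each parameter individually and often converges quickly without much learning-rate tuning.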
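A sketch of the loss and optimizer setup. A one-layer stand-in model keeps the block self-contained, and the learning rate `1e-3` is an assumption (it is also Adam's default in PyTorch). The last lines check the loss on a case that can be worked by hand: with all-zero logits every class gets probability 1/10, so the cross-entropy is ln 10 ≈ 2.3026.

```python
import math

import torch
from torch import nn

# Stand-in model (flatten + one linear layer) so this block runs on its own.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

loss_fn = nn.CrossEntropyLoss()                            # raw logits + integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-3 is also Adam's default

# All-zero logits -> uniform prediction over 10 classes -> loss = -ln(1/10) = ln(10).
logits = torch.zeros(4, 10)
targets = torch.tensor([0, 3, 7, 9])
loss = loss_fn(logits, targets)
print(loss.item(), math.log(10))
```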
During training, the code loops over each batch and:
Moves the data to the device
Makes predictions (forward pass)
Calculates the loss
Backpropagates the loss to compute gradients
Updates the model weights
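The per-batch steps can be sketched as a training function like the one below. The name `train` and the returned list of batch losses are illustrative choices. Gradients are cleared each iteration with `optimizer.zero_grad()` because PyTorch accumulates them across `backward()` calls. The smoke test at the end runs one epoch on random data with a small stand-in model.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(dataloader, model, loss_fn, optimizer, device):
    """One pass over the training data; returns the loss of each batch."""
    model.train()  # training-mode behaviour for layers like dropout, if present
    batch_losses = []
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)  # move the batch to the device
        pred = model(X)                    # forward pass: logits
        loss = loss_fn(pred, y)            # compare predictions with labels

        optimizer.zero_grad()              # clear gradients from the previous batch
        loss.backward()                    # backpropagation: compute gradients
        optimizer.step()                   # update the weights using those gradients
        batch_losses.append(loss.item())
    return batch_losses


# Offline smoke test on random data with a small stand-in model.
torch.manual_seed(0)
device = "cpu"
data = TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=64)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
before = model[1].weight.detach().clone()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = train(loader, model, nn.CrossEntropyLoss(), optimizer, device)
print(losses)
```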
Evaluates the model on the test data:
Computes the average loss and the accuracy
pred.argmax(1) picks, for each image, the class with the highest score, i.e. the predicted digit
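A sketch of the evaluation step under the same assumptions (illustrative function name `evaluate`, stand-in model and data). `torch.no_grad()` skips gradient tracking, and `model.eval()` switches layers such as dropout or batch norm to inference behaviour; this network has neither, but calling it does no harm.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(dataloader, model, loss_fn, device):
    """Return (average loss per batch, accuracy) over the whole dataset."""
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    total_loss, correct = 0.0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            total_loss += loss_fn(pred, y).item()
            # argmax(1): index of the largest logit in each row = predicted digit
            correct += (pred.argmax(1) == y).sum().item()
    return total_loss / num_batches, correct / size


# argmax along dim 1 picks the highest-scoring class in each row (image).
scores = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0, 0.5]])
print(scores.argmax(1))  # tensor([1, 0])

# Offline smoke test on random data with a small stand-in model.
torch.manual_seed(0)
data = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
avg_loss, accuracy = evaluate(DataLoader(data, batch_size=64), model, nn.CrossEntropyLoss(), "cpu")
print(f"Accuracy: {100 * accuracy:.1f}%, Avg loss: {avg_loss:.4f}")
```

The average loss here is the mean of the per-batch mean losses, so a smaller final batch carries slightly more weight per sample than the others.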
Trains the model for 5 epochs, evaluating it on the test set after each epoch to monitor performance.
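Putting the pieces together, an epoch loop might look like the following sketch. It runs on the CPU with random stand-in data and a one-layer stand-in model, so the printed accuracy is expected to stay near chance level; with the real MNIST loaders and the network described above, the numbers would be meaningful. Shuffling the training loader and the helper names are assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(loader, model, loss_fn, optimizer):
    model.train()
    for X, y in loader:
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def evaluate(loader, model, loss_fn):
    model.eval()
    total_loss, correct = 0.0, 0
    with torch.no_grad():
        for X, y in loader:
            pred = model(X)
            total_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)


# Offline stand-ins for the MNIST loaders and the network described above.
torch.manual_seed(0)
train_set = TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,)))
test_set = TensorDataset(torch.rand(128, 1, 28, 28), torch.randint(0, 10, (128,)))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 5
history = []
for t in range(epochs):
    train_one_epoch(train_loader, model, loss_fn, optimizer)
    avg_loss, accuracy = evaluate(test_loader, model, loss_fn)  # monitor after every epoch
    history.append((avg_loss, accuracy))
    print(f"Epoch {t + 1}: test accuracy {100 * accuracy:.1f}%, avg loss {avg_loss:.4f}")
print("Done!")
```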
Saves the trained model's weights (its state_dict) to disk.
Loads the saved weights back into a new instance of the same model.
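Saving and loading weights typically uses `state_dict()` with `torch.save`, then `torch.load` with `load_state_dict`, as sketched below. The file name `model.pth`, the temporary directory, and the `build_model` helper are hypothetical. `weights_only=True` requires PyTorch 1.13 or later and restricts loading to tensors and other plain data.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    # Same layer sizes as the network described above.
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 512), nn.ReLU(),
        nn.Linear(512, 512), nn.ReLU(),
        nn.Linear(512, 10),
    )


model = build_model()  # stands in for the trained model

path = os.path.join(tempfile.mkdtemp(), "model.pth")  # hypothetical file name
torch.save(model.state_dict(), path)                   # save only the weights

# Loading: rebuild the same architecture, then copy the saved weights into it.
restored = build_model()
restored.load_state_dict(torch.load(path, weights_only=True))
restored.eval()  # inference mode before making predictions
```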
Runs inference on 10 individual images from the test set.
Compares predicted and actual digits.
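A sketch of the per-image inference loop, with a hypothetical stand-in model and test set. A single dataset item has shape [1, 28, 28] with no batch dimension, so `unsqueeze(0)` adds one before calling the model. The `classes` list is an assumption; for MNIST the class index already equals the digit.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

classes = [str(d) for d in range(10)]  # class index i is the digit i

# Offline stand-ins for the trained model and the MNIST test set.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
test_data = TensorDataset(torch.rand(10, 1, 28, 28), torch.randint(0, 10, (10,)))

model.eval()
results = []
with torch.no_grad():
    for i in range(10):
        x, y = test_data[i]              # one image [1, 28, 28] and its label
        pred = model(x.unsqueeze(0))     # add a batch dimension -> [1, 1, 28, 28]
        predicted, actual = classes[pred.argmax(1).item()], classes[int(y)]
        results.append((predicted, actual))
        print(f'Predicted: "{predicted}", Actual: "{actual}"')

correct = sum(p == a for p, a in results)
print(f"{correct}/10 correct")
```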