MNIST Classification Using a Neural Network in PyTorch
1. Introduction
1.1. MNIST Dataset and the Classification Problem
The MNIST dataset is a collection of handwritten digits, each represented as a grayscale image of size 28x28 pixels. The corresponding classification problem involves classifying these images into one of the 10 digits (0 to 9). We treat each image as a vector of dimension $784 = 28 \times 28$ and aim to predict the probability distribution over the 10 possible classes, i.e. a vector $\hat{y} \in [0, 1]^{10}$ whose entries sum to 1.
In the example below, we use a neural network to classify MNIST images. Most of the predictions are correct, except for the images of the digits 2 and 3, which are predicted as 0 and 7, respectively.
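To make this representation concrete, here is a minimal sketch (assuming torchvision is available) that loads one training image, flattens it into a 784-dimensional vector, and shows how a softmax turns 10 made-up raw scores into such a probability distribution:

import torch
import torchvision
import torchvision.transforms as transforms

# Load one training image (downloads MNIST on first run)
dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)
image, label = dataset[0]             # image has shape [1, 28, 28]
x = image.reshape(-1)                 # flattened vector of dimension 784
print(x.shape)                        # torch.Size([784])

scores = torch.randn(10)              # made-up raw scores for the 10 classes
probs = torch.softmax(scores, dim=0)  # a probability distribution over 0-9
print(probs.sum())                    # sums to 1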
1.2. Feedforward Neural Network
Here, we solve this classification problem by using a feedforward neural network (NN).
A neural network is a parametric function $f_\theta$, where $\theta$ is a set of parameters. Given a training data set $(x_i, y_i)$ for $i = 1, \dots, N$, training a NN is the process of finding $\theta$ so as to minimize the following average loss value:

$$L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell\big(f_\theta(x_i), y_i\big),$$

where $\ell$ is a loss function that measures the discrepancy between the NN's output $f_\theta(x_i)$ and the ground truth output $y_i$.
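As a toy numerical illustration of this objective, the sketch below uses made-up scalar outputs and targets, with the squared error as the per-sample loss $\ell$:

import torch

# Made-up outputs f_theta(x_i) and targets y_i for N = 4 samples,
# with squared error loss(a, b) = (a - b)^2 as the per-sample loss
outputs = torch.tensor([0.9, 0.2, 0.7, 0.4])
targets = torch.tensor([1.0, 0.0, 1.0, 1.0])

losses = (outputs - targets) ** 2   # per-sample loss values
print(losses.mean())                # the average loss: tensor(0.1250)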
Each NN architecture corresponds to a specific definition of $f_\theta$. Assume that $\theta = (W_1, b_1, \dots, W_{L+1}, b_{L+1})$, where $W_k \in \mathbb{R}^{n_k \times n_{k-1}}$ and $b_k \in \mathbb{R}^{n_k}$; then the $L$-hidden-layer NN is usually defined as the composition of the layer maps $f_1, \dots, f_{L+1}$, i.e.

$$f_\theta = f_{L+1} \circ f_L \circ \cdots \circ f_1,$$

where $f_k$ is the activation of a linear map,

$$f_k(z) = \sigma_k(W_k z + b_k),$$

where $\sigma_k$ is a pre-chosen activation function that is applied element-wise to its vector input. Typical choices of activation function include the linear (identity) function

$$\sigma(t) = t$$

and the ReLU function

$$\sigma(t) = \max(0, t).$$
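To connect this notation to code, here is a minimal sketch of $f_\theta$ as a composition of ReLU-activated linear maps followed by a final linear layer; the layer widths are illustrative choices, not prescribed by anything above:

import torch

# Illustrative (made-up) layer widths: n0 = 784, n1 = n2 = 100, n3 = 10
sizes = [784, 100, 100, 10]
params = [(torch.randn(m, n), torch.randn(m))   # W_k is n_k x n_{k-1}, b_k is n_k
          for n, m in zip(sizes[:-1], sizes[1:])]

def f(x):
    for W, b in params[:-1]:
        x = torch.relu(W @ x + b)               # hidden layers: sigma = ReLU
    W, b = params[-1]
    return W @ x + b                            # output layer: linear activation

x = torch.randn(784)                            # a flattened 28x28 image
print(f(x).shape)                               # torch.Size([10])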
1.3. PyTorch
PyTorch is a popular and powerful deep learning framework that is seamlessly integrated with Python. It empowers developers and researchers to create, train, and evaluate neural network models, from simple feedforward networks to intricate deep learning architectures, for tasks such as computer vision, natural language processing, and reinforcement learning. In this blog post, we'll leverage PyTorch's flexibility to develop and train a neural network for the MNIST classification task.
2. Implementation
The following Python code uses PyTorch to implement and train a NN for MNIST digit classification. The code is adapted from the excellent YouTube video tutorial by Patrick Loeber and is available at this GitHub link. Note that you can execute this code in Google Colab without installing anything.
In the following code, the MNIST dataset is automatically downloaded. The batch_size parameter specifies the number of images that are processed together during each training iteration. The num_epochs parameter specifies the number of times the entire training dataset is passed through the NN during training.
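To see what these two hyper-parameters imply, a quick back-of-the-envelope computation using MNIST's 60,000 training images and the values chosen in the code below:

num_train_images = 60000   # size of the MNIST training set
batch_size = 100
num_epochs = 2

steps_per_epoch = num_train_images // batch_size   # 600 iterations per epoch
total_steps = steps_per_epoch * num_epochs         # 1200 gradient updates overall
print(steps_per_epoch, total_steps)                # 600 1200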
We define the NN as a class that inherits from PyTorch's nn.Module. In the __init__() method, we define the NN's layers. Here, we consider a shallow NN with a single hidden layer, i.e. $L = 1$. We then create the forward() method, which defines the NN's function

$$f_\theta(x) = W_2\,\sigma(W_1 x + b_1) + b_2,$$

with $\sigma = \mathrm{ReLU}$.
In the training step, we use the cross-entropy loss function and the Adam optimizer. With this example code, the final accuracy on the test dataset should be at least 90%.
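One detail worth noting before the full listing: PyTorch's nn.CrossEntropyLoss applies log-softmax internally (it combines LogSoftmax and NLLLoss), which is why the network below outputs raw logits without a final softmax layer. A small sketch with made-up values illustrates the equivalence:

import torch
import torch.nn as nn

logits = torch.randn(5, 10)             # made-up raw outputs for 5 images
labels = torch.randint(0, 10, (5,))     # made-up ground-truth digits

# CrossEntropyLoss combines LogSoftmax and NLLLoss internally
loss_a = nn.CrossEntropyLoss()(logits, labels)
loss_b = nn.NLLLoss()(torch.log_softmax(logits, dim=1), labels)
print(loss_a.item(), loss_b.item())     # the two values coincide

The complete script follows.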
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
input_size = 784 # 28x28
hidden_size = 500
num_classes = 10
num_epochs = 2
batch_size = 100
learning_rate = 0.001
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
# examples = iter(test_loader)
# example_data, example_targets = next(examples)
# for i in range(6):
#     plt.subplot(2, 3, i+1)
#     plt.imshow(example_data[i][0], cmap='gray')
# plt.show()
# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.input_size = input_size
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        # no activation and no softmax at the end
        return out
model = NeuralNet(input_size, hidden_size, num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # origin shape: [100, 1, 28, 28]
        # resized: [100, 784]
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        flat_images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(flat_images)
        # max returns (value, index)
        _, predicted = torch.max(outputs.data, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network on the 10000 test images: {acc} %')

# Plot a few test images from the last batch with their predictions
for i in range(1, 9):
    lab = labels[i].item()      # ground-truth digit
    pred = predicted[i].item()  # predicted digit
    plt.subplot(2, 4, i)
    plt.imshow(images[i][0], cmap='gray')
    plt.title(f"prediction={pred}" if i == 1 else f"{pred}")
    plt.axis("off")
plt.show()