import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt
Basic Intro
This blog walks through a basic CNN for image classification on the SVHN dataset. The goal is to give an overview of how a dataset is loaded, how a model architecture is defined with the PyTorch library, and how the overall training loop works.
# Define transformations for the SVHN dataset
# Convert images to PyTorch Tensors
# Normalize the tensors. The values (0.5, 0.5, 0.5) are standard for normalizing to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.SVHN(root='./data', split='train',
                                          download=True, transform=transform)
val_dataset = torchvision.datasets.SVHN(root='./data', split='test',
                                        download=True, transform=transform)
len(train_dataset), len(val_dataset)
(73257, 26032)
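Before building the model, it helps to sanity-check what a single transformed sample looks like. The short check below is an illustrative addition (not part of the original notebook): each image should be a 3x32x32 tensor with values in roughly [-1, 1] after Normalize, since (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].

# Quick sanity check on one transformed sample (illustrative addition)
sample_img, sample_label = train_dataset[0]
print(sample_img.shape)                    # torch.Size([3, 32, 32])
print(sample_img.min(), sample_img.max())  # values lie within [-1, 1]
print(sample_label)                        # an integer class label 0-9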
def imshow(img, title=None):
    img = img / 2 + 0.5  # unnormalize [-1, 1] -> [0, 1]
    np_img = img.numpy()
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    if title:
        plt.title(title)
    plt.axis("off")
print("Random samples from SVHN training set:")
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for i in range(5):
    idx = torch.randint(0, len(train_dataset), (1,)).item()
    img, label = train_dataset[idx]
    plt.subplot(1, 5, i+1)
    imshow(img)
    plt.title(f"Label: {label}")
    plt.axis("off")
plt.show()
Random samples from SVHN training set:
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
print(f"Data loaded. Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}\n")
Data loaded. Training samples: 73257, Validation samples: 26032
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Show the batch
print("A batch from the DataLoader:")
imshow(torchvision.utils.make_grid(images, nrow=8),
       title=" ".join(str(label.item()) for label in labels))
plt.show()
A batch from the DataLoader:
The model must learn to focus on the central digit and ignore the surrounding context.
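It is also worth confirming the shapes the DataLoader produces; a minimal check (added here for illustration, reusing the images and labels fetched above):

# Illustrative check of batch shapes
print(images.shape)  # torch.Size([64, 3, 32, 32]) -> batch of 64 RGB 32x32 images
print(labels.shape)  # torch.Size([64]) -> one integer label per image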
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
print("Step 2: Model Definition")
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # SVHN images are 3x32x32 (3 color channels)
        self.features = nn.Sequential(
            # Input: 3x32x32
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),   # Output: 16x32x32
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # Output: 16x16x16
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # Output: 32x16x16
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)                   # Output: 32x8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = SimpleCNN(num_classes=10)
print("Model architecture:")
print(model)
print("\n")
Step 2: Model Definition
Model architecture:
SimpleCNN(
(features): Sequential(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Flatten(start_dim=1, end_dim=-1)
(1): Linear(in_features=2048, out_features=128, bias=True)
(2): ReLU()
(3): Linear(in_features=128, out_features=10, bias=True)
)
)
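As a quick sanity check (an illustrative addition, assuming the model above), a dummy forward pass confirms that the flattened feature size of 32 * 8 * 8 = 2048 matches the first Linear layer, and a one-liner counts the trainable parameters:

# Sketch: verify output shape and count parameters
dummy = torch.randn(1, 3, 32, 32)
print(model(dummy).shape)  # torch.Size([1, 10]) -> one logit per class
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")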
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
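Note that nn.CrossEntropyLoss applies log-softmax internally, which is why the model's final layer outputs raw logits with no softmax. A tiny standalone example (illustrative only, with made-up logits and targets):

# CrossEntropyLoss takes raw logits and integer class targets
example_logits = torch.randn(4, 10)           # batch of 4, 10 classes
example_targets = torch.tensor([3, 1, 0, 9])  # ground-truth class indices
print(criterion(example_logits, example_targets))  # a scalar loss tensor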
NUM_EPOCHS = 20
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {running_loss/len(train_loader):.4f}")

print("Training finished.\n")
Epoch [1/20], Loss: 0.9173
Epoch [2/20], Loss: 0.4778
Epoch [3/20], Loss: 0.4011
Epoch [4/20], Loss: 0.3499
Epoch [5/20], Loss: 0.3119
Epoch [6/20], Loss: 0.2827
Epoch [7/20], Loss: 0.2563
Epoch [8/20], Loss: 0.2327
Epoch [9/20], Loss: 0.2116
Epoch [10/20], Loss: 0.1927
Epoch [11/20], Loss: 0.1746
Epoch [12/20], Loss: 0.1595
Epoch [13/20], Loss: 0.1442
Epoch [14/20], Loss: 0.1314
Epoch [15/20], Loss: 0.1208
Epoch [16/20], Loss: 0.1081
Epoch [17/20], Loss: 0.1003
Epoch [18/20], Loss: 0.0915
Epoch [19/20], Loss: 0.0826
Epoch [20/20], Loss: 0.0747
Training finished.
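The loop above runs on the CPU. If a GPU is available, the usual pattern moves the model and each batch to the device; the sketch below shows the per-batch shape of that change (an illustrative variant, not part of the original run):

# Sketch: GPU-aware training step (same model/criterion/optimizer as above)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()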
model.eval()
all_preds = []
all_labels = []

# `torch.no_grad()` disables gradient calculation for efficiency
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)

        all_preds.extend(predicted.numpy())
        all_labels.extend(labels.numpy())

accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)
print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Validation Precision: {precision:.4f}")
print(f"Validation Recall: {recall:.4f}")
print(f"Validation F1 Score: {f1:.4f}\n")
Validation Accuracy: 0.8778
Validation Precision: 0.8787
Validation Recall: 0.8778
Validation F1 Score: 0.8778
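The weighted averages above can hide weak classes; sklearn's classification_report gives a per-digit breakdown. This is an optional addition that reuses the predictions already collected:

# Optional: per-class precision/recall/F1
from sklearn.metrics import classification_report
print(classification_report(all_labels, all_preds, target_names=list(classes), zero_division=0))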
print("Step 5: Saving and Loading the Model")
MODEL_PATH = "svhn_cnn_weights.pth"

# Save only the learned parameters (model.state_dict()), not the full model object
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")
Step 5: Saving and Loading the Model
Model saved to svhn_cnn_weights.pth
new_model = SimpleCNN(num_classes=10)

new_model.load_state_dict(torch.load(MODEL_PATH))
print("Model weights loaded into a new instance.\n")
Model weights loaded into a new instance.
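One version note: recent PyTorch releases default torch.load to weights_only=True for safety when unpickling. If your installed version warns about this, being explicit avoids the warning (a hedged suggestion, depending on your PyTorch version):

# Explicit form for newer PyTorch versions (fine for a plain state_dict)
new_model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))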
print("Step 6: Inference on New Data")
new_data_point, true_label_idx = val_dataset[0]
new_data_point = new_data_point.unsqueeze(0)  # add a batch dimension: 3x32x32 -> 1x3x32x32

new_model.eval()

with torch.no_grad():
    prediction_logits = new_model(new_data_point)
    # Get the class with the highest score
    predicted_class_idx = torch.argmax(prediction_logits, dim=1).item()
print(f"True Label: {classes[true_label_idx]}")
print(f"Predicted Class: {classes[predicted_class_idx]}")
Step 6: Inference on New Data
True Label: 5
Predicted Class: 5
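Since the model outputs raw logits, a softmax converts them to probabilities if you also want a confidence estimate alongside the predicted class (an illustrative extension, reusing prediction_logits from above):

# Sketch: turn logits into class probabilities
probs = torch.softmax(prediction_logits, dim=1)
confidence = probs[0, predicted_class_idx].item()
print(f"Confidence in predicted class: {confidence:.2%}")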