Image Retrieval Using Pre-trained ResNet in 5 Steps

Got an image and want to find similar images in your image folder or database? Image retrieval is useful in a lot of applications: a search engine with an image search function, for example, or helping people find the image they are looking for when they are not sure where to start.

In this article, we’ll use a pre-trained ResNeXt-50 (a ResNet variant shipped with torchvision) as a feature extractor together with k-nearest neighbours to retrieve similar images. The dataset we use is CIFAR-100, collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton (Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, 2009).

For the demo, we created a smaller dataset by randomly selecting 5 images from each class of the training set and 1 image from each class of the test set. This gives us 500 training images as retrieval candidates and 100 test images as queries. The folder structure looks as follows:

-cifar100small
--train
---apple
---baby
---bear
...
--test
---apple
---baby
---bear
...

This folder structure facilitates data loading with the torchvision ImageFolder API.
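If you want to build such a folder yourself, here is a minimal sketch of how a subset could be exported from torchvision’s built-in CIFAR100 dataset. The export_subset helper and file names are my own illustration (the article does not show this step), and for simplicity it takes the first few images per class rather than a random sample.

import os
from collections import defaultdict
from torchvision import datasets

def export_subset(split_dir, train, per_class):
    # Download CIFAR-100 and save a handful of images per class as PNGs
    # in an ImageFolder-compatible layout: split_dir/<class_name>/<file>.png
    ds = datasets.CIFAR100(root="data", train=train, download=True)
    counts = defaultdict(int)
    for image, label in ds:  # image is a PIL.Image, label an int
        name = ds.classes[label]
        if counts[name] >= per_class:
            continue
        class_dir = os.path.join(split_dir, name)
        os.makedirs(class_dir, exist_ok=True)
        image.save(os.path.join(class_dir, f"{name}_{counts[name]}.png"))
        counts[name] += 1

export_subset("cifar100small/train", train=True, per_class=5)
export_subset("cifar100small/test", train=False, per_class=1)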

Now let’s get started with the exciting part!

1. Load data

import torch
from torchvision import datasets, transforms

train_image_dir = "cifar100small/train"  # path to train images
test_image_dir = "cifar100small/test"    # path to test images

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Train dataset and dataloader (our candidates)
candidates_dataset = datasets.ImageFolder(train_image_dir, transform=transform)
candidates_loader = torch.utils.data.DataLoader(
    candidates_dataset, batch_size=64, shuffle=False, num_workers=2)
candidates_path = candidates_dataset.imgs

# Test dataset and dataloader
test_dataset = datasets.ImageFolder(test_image_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=2)
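One detail worth mentioning: torchvision’s pre-trained models were trained on ImageNet-normalised inputs, so adding the standard normalisation to the transform above usually gives slightly better features. The feature extractor still works without it, but if you want it, the transform would look like this:

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # standard ImageNet statistics expected by torchvision's pre-trained models
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])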

2. Load pre-trained ResNet Model

import torch
from torchvision import models
# Use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use a pre-trained ResNeXt-50 (a ResNet variant) from torchvision
net = models.resnext50_32x4d(pretrained=True)
# Create a new network without the final fully connected layer
net = torch.nn.Sequential(*(list(net.children())[:-1]))
net = net.to(device)
net.eval()  # inference mode: freeze batch-norm statistics, no dropout
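The truncated network ends at the global average-pooling layer, so each image is mapped to a 2048-dimensional feature vector. A quick sanity check (my own sketch, not part of the original code):

with torch.no_grad():
    dummy = torch.zeros(1, 3, 224, 224, device=device)
    print(net(dummy).shape)  # expected: torch.Size([1, 2048, 1, 1])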

3. Calculate the features for the candidates

import numpy as np

candidates = []
for data in candidates_loader:
    image, label = data
    image = image.to(device)
    output = net(image)
    output = output.detach().cpu().numpy().reshape(
        -1, np.prod(output.size()[1:]))
    candidates.append(output)
candidates = np.concatenate(candidates)
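After the loop, candidates should be a single NumPy matrix with one 2048-dimensional row per training image; a quick check (added here for illustration):

print(candidates.shape)  # expected: (500, 2048)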

4. Fit the candidates

from sklearn.neighbors import NearestNeighbors

k = 10
knn = NearestNeighbors(n_neighbors=k, metric="cosine")
knn.fit(candidates)
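We use cosine distance because we care about the direction of the feature vectors rather than their magnitude; for this metric scikit-learn falls back to a brute-force search, which is perfectly fine for 500 candidates. As a quick sanity check (a sketch of my own, not in the original article), the nearest neighbour of a candidate’s own feature vector should be itself at a distance close to zero:

dist, idx = knn.kneighbors(candidates[:1])
print(idx[0][0], dist[0][0])  # expected: 0 and a distance close to 0.0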

5. Start predicting!

def predict(test_loader):
    res = []
    for i, data in enumerate(test_loader):
        image, label = data
        image = image.to(device)
        output = net(image)
        output = output.detach().cpu().numpy().reshape(
            -1, np.prod(output.size()[1:]))

        _, neighbors_idx = knn.kneighbors(output)
        res.append((test_dataset.imgs[i][0], neighbors_idx[0]))
    return res

res = predict(test_loader)
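If you want a rough quantitative feel for the results, here is a small evaluation sketch of my own (the article itself only inspects results visually): it counts how often the query’s class appears among its k retrieved neighbours, using the folder name as the class label.

import os

def class_of(path):
    # the parent folder name is the class label in an ImageFolder layout
    return os.path.basename(os.path.dirname(path))

hits = 0
for query_path, neighbors_idx in res:
    retrieved = [class_of(candidates_path[i][0]) for i in neighbors_idx]
    hits += int(class_of(query_path) in retrieved)
print(f"{hits}/{len(res)} queries have a same-class image in their top {k}")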

To visualise the results, you can use this helper function:

import matplotlib.pyplot as plt
from PIL import Image

def plot_result(query_image, candidates_idx, candidates_path):
    plt.figure(figsize=(20, 10))
    columns = 5
    # show query image
    image = Image.open(query_image)
    image = np.asarray(image)
    ax = plt.subplot(3, columns, 1)
    ax.set_title(f"Query: {query_image.split('/')[-1]}")
    plt.imshow(image)
    for i, idx in enumerate(candidates_idx):
        image_path = candidates_path[idx][0]
        image = Image.open(image_path)
        image = np.asarray(image)
        ax = plt.subplot(3, columns, i + 1 + 5)
        ax.set_title(image_path.split("/")[-1])
        plt.imshow(image)

And then plot some results!

test_image_1, candidates_idx = res[5]
plot_result(test_image_1, candidates_idx, candidates_path)

Let’s take a look at some results:

Example 1: camel

Top 10 nearest neighbours for query image “camel”.

Example 2: armchair

Top 10 nearest neighbours for query image “armchair”.

Example 3: bowl

Top 10 nearest neighbours for query image “bowl”.

Example 4: bed

Top 10 nearest neighbours for query image “bed”.

Voilà! The results look quite good, as the nearest neighbours are actually quite similar to the query image. Of course, in some cases it works less well, such as the example with the pink bed. But at least it got a couch as a similar neighbour, which is kind of similar to a bed 😂

If you like my articles, don’t forget to give them a clap 👏 and share them with your friends and colleagues! Cheers! 😉
