PyTorch Custom Datasets

Zero to Mastery Learn PyTorch for Deep Learning

04. PyTorch Custom Datasets

In the last notebook, notebook 03, we looked at how to build computer vision models on an in-built dataset
in PyTorch (FashionMNIST).

The steps we took are similar across many different problems in machine learning.

Find a dataset, turn the dataset into numbers, build a model (or find an existing model) to find patterns in
those numbers that can be used for prediction.

PyTorch has many built-in datasets used for a wide range of machine learning benchmarks. However,
you'll often want to use your own custom dataset.

What is a custom dataset?

A custom dataset is a collection of data relating to a specific problem you're working on.

In essence, a custom dataset can be made up of almost anything.

For example, if we were building a food image classification app like Nutrify, our custom dataset might be
images of food.

Or if we were trying to build a model to classify whether or not a text-based review on a website was
positive or negative, our custom dataset might be examples of existing customer reviews and their ratings.

Or if we were trying to build a sound classification app, our custom dataset might be sound samples
alongside their sample labels.

Or if we were trying to build a recommendation system for customers purchasing things on our website,
our custom dataset might be examples of products other people have bought.

PyTorch includes many existing functions to load in various custom datasets in the TorchVision,
TorchText, TorchAudio and TorchRec domain libraries.

But sometimes these existing functions may not be enough.

In that case, we can always subclass PyTorch's Dataset class and customize it to our liking.
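For reference, PyTorch's map-style datasets are organized around two methods: `__len__()` (how many samples there are) and `__getitem__()` (return one sample and its label by index). The sketch below shows that protocol in plain Python so it runs without PyTorch installed; the class name `ListDataset` is hypothetical, and a real custom dataset would subclass `torch.utils.data.Dataset` and usually load an image from disk inside `__getitem__()`.

```python
# Plain-Python sketch of the map-style dataset protocol (__len__ + __getitem__).
# Assumption: a real PyTorch version would subclass torch.utils.data.Dataset;
# "ListDataset" is a hypothetical name used only for illustration.


class ListDataset:
    """Wraps a list of (sample, label) pairs and an optional transform."""

    def __init__(self, samples, transform=None):
        self.samples = list(samples)
        self.transform = transform

    def __len__(self):
        # How many samples the dataset holds
        return len(self.samples)

    def __getitem__(self, index):
        # Return one (sample, label) pair, transforming the sample if asked to
        sample, label = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


demo_dataset = ListDataset([("pizza_1.jpg", 0), ("steak_1.jpg", 1)], transform=str.upper)
print(len(demo_dataset))  # 2
print(demo_dataset[1])    # ('STEAK_1.JPG', 1)
```

Here `str.upper` stands in for an image transform; the point is that indexing returns a transformed sample plus its label.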

What we're going to cover

We're going to be applying the PyTorch Workflow we covered in notebook 01 and notebook 02 to a
computer vision problem.

But instead of using an in-built PyTorch dataset, we're going to be using our own dataset of pizza, steak
and sushi images.

The goal will be to load these images and then build a model to train and predict on them.

What we're going to build: we'll use torchvision.datasets as well as our own custom Dataset class to
load in images of food, and then we'll build a PyTorch computer vision model to hopefully be able to classify them.

Specifically, we're going to cover:

Topic | Contents
--- | ---
0. Importing PyTorch and setting up device-agnostic code | Let's get PyTorch loaded and then follow best practice to set up our code to be device-agnostic.
1. Get data | We're going to be using our own custom dataset of pizza, steak and sushi images.
2. Become one with the data (data preparation) | At the beginning of any new machine learning problem, it's paramount to understand the data you're working with. Here we'll take some steps to figure out what data we have.
3. Transforming data | Often, the data you get won't be 100% ready to use with a machine learning model. Here we'll look at some steps we can take to transform our images so they're ready to be used with a model.
4. Loading data with ImageFolder (option 1) | PyTorch has many in-built data loading functions for common types of data. ImageFolder is helpful if our images are in standard image classification format.
5. Loading image data with a custom Dataset | What if PyTorch didn't have an in-built function to load data with? This is where we can build our own custom subclass of Dataset.
6. Other forms of transforms (data augmentation) | Data augmentation is a common technique for expanding the diversity of your training data. Here we'll explore some of torchvision's in-built data augmentation functions.
7. Model 0: TinyVGG without data augmentation | By this stage, we'll have our data ready, so let's build a model capable of fitting it. We'll also create some training and testing functions for training and evaluating our model.
8. Exploring loss curves | Loss curves are a great way to see how your model is training/improving over time. They're also a good way to see if your model is underfitting or overfitting.
9. Model 1: TinyVGG with data augmentation | By now, we've tried a model without data augmentation, so how about we try one with it?
10. Compare model results | Let's compare our different models' loss curves, see which performed better and discuss some options for improving performance.
11. Making a prediction on a custom image | Our model is trained on a dataset of pizza, steak and sushi images. In this section we'll cover how to use our trained model to predict on an image outside of our existing dataset.

Where can you get help?

All of the materials for this course live on GitHub.

If you run into trouble, you can ask a question on the course GitHub Discussions page there too.

And of course, there's the PyTorch documentation and PyTorch developer forums, a very helpful place for
all things PyTorch.

0. Importing PyTorch and setting up device-agnostic code

In [1]: import torch
from torch import nn

# Note: this notebook requires torch >= 1.10.0
torch.__version__

Out[1]: '1.12.1+cu113'

And now let's follow best practice and set up device-agnostic code.

Note: If you're using Google Colab, and you don't have a GPU turned on yet, it's now time to turn one on
via Runtime -> Change runtime type -> Hardware accelerator -> GPU. If you do this, your runtime
will likely reset and you'll have to run all of the cells above by going Runtime -> Run before.

In [2]: # Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

Out[2]: 'cuda'

1. Get data
First things first, we need some data.

And like any good cooking show, some data has already been prepared for us.

We're going to start small.

Because we're not looking to train the biggest model or use the biggest dataset yet.

Machine learning is an iterative process: start small, get something working and increase when necessary.

The data we're going to be using is a subset of the Food101 dataset.

Food101 is a popular computer vision benchmark. It contains 1,000 images of each of 101 different kinds of food,
totaling 101,000 images (75,750 train and 25,250 test).

Can you think of 101 different foods?

Can you think of a computer program to classify 101 foods?

I can.

A machine learning model!

Specifically, a PyTorch computer vision model like we covered in notebook 03.

Instead of 101 food classes though, we're going to start with 3: pizza, steak and sushi.

And instead of 1,000 images per class, we're going to start with a random 10% (start small, increase when necessary).
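Taking a random 10% of a class can be sketched with Python's `random.sample()`. This is an illustration under assumptions, not the code from the data creation notebook: the file names are made up, and the 3:1 train/test split simply mirrors Food101's 750/250 images per class. The real per-class counts in our subset vary a little, as the directory walk later in this notebook shows.

```python
import random

# Hypothetical file names standing in for the 1,000 images of one Food101 class
all_image_ids = [f"{i:07d}.jpg" for i in range(1000)]

random.seed(42)  # fixed seed so the same subset comes out every time
subset_size = len(all_image_ids) // 10  # 10% of 1,000 -> 100
subset = random.sample(all_image_ids, k=subset_size)  # sampling without replacement

# Assumed split: 3/4 train, 1/4 test (same 3:1 ratio as Food101's 750/250 per class)
n_train = subset_size * 3 // 4
train_ids, test_ids = subset[:n_train], subset[n_train:]

print(len(subset), len(train_ids), len(test_ids))  # 100 75 25
```

Integer division keeps the sizes exact, and `random.sample()` never picks the same file twice, so the train and test lists cannot overlap.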

If you'd like to see where the data came from, see the following resources:

Original Food101 dataset and paper website.

torchvision.datasets.Food101 - the version of the data I downloaded for this notebook.

extras/04_custom_data_creation.ipynb - a notebook I used to format the Food101 dataset to use for this notebook.

data/ - the zip archive of pizza, steak and sushi images from Food101,
created with the notebook linked above.

Let's write some code to download the formatted data from GitHub.

Note: The dataset we're about to use has been pre-formatted for what we'd like to use it for. However,
you'll often have to format your own datasets for whatever problem you're working on. This is a regular
practice in the machine learning world.

In [3]: import requests
import zipfile
from pathlib import Path

# Setup path to data folder
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't exist, download it and prepare it...
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download pizza, steak, sushi data
    with open(data_path / "", "wb") as f:
        request = requests.get("...")
        print("Downloading pizza, steak, sushi data...")
        f.write(request.content)

    # Unzip pizza, steak, sushi data
    with zipfile.ZipFile(data_path / "", "r") as zip_ref:
        print("Unzipping pizza, steak, sushi data...")
        zip_ref.extractall(image_path)
data/pizza_steak_sushi directory exists.
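The download cell follows a common "only fetch if missing" pattern. Below is an offline sketch of the same control flow using only the standard library: instead of downloading, it builds a tiny zip archive in a temporary folder (the helper name `prepare_data()`, the archive name and its contents are all hypothetical) and extracts it with `zipfile.ZipFile.extractall()`.

```python
import tempfile
import zipfile
from pathlib import Path


def prepare_data(target_dir: Path, zip_path: Path) -> bool:
    """Extract zip_path into target_dir unless target_dir already exists.

    Returns True if the archive was extracted, False if the folder was already there.
    """
    if target_dir.is_dir():
        print(f"{target_dir} directory exists.")
        return False
    print(f"Did not find {target_dir} directory, creating one...")
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(target_dir)
    return True


# Stand-in archive with hypothetical file names and fake image bytes
demo_tmp = Path(tempfile.mkdtemp())
demo_zip_path = demo_tmp / "example_images.zip"
with zipfile.ZipFile(demo_zip_path, "w") as zf:
    zf.writestr("train/pizza/0001.jpg", b"not really a jpeg")
    zf.writestr("test/sushi/0002.jpg", b"not really a jpeg")

demo_image_path = demo_tmp / "example_images"
first_run = prepare_data(demo_image_path, demo_zip_path)   # extracts the archive
second_run = prepare_data(demo_image_path, demo_zip_path)  # skips: folder already exists
print(first_run, second_run)  # True False
```

Checking for the folder first makes the cell safe to re-run: the second call does no work, which is exactly why the notebook printed "directory exists."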

2. Become one with the data (data preparation)

Dataset downloaded!

Time to become one with it.

This is another important step before building a model.

As Abraham Lossfunction said...

Data preparation is paramount. Before building a model, become one with the data. Ask: What am I trying to
do here? Source: @mrdbourke Twitter.

What's inspecting the data and becoming one with it?

Before starting a project or building any kind of model, it's important to know what data you're working with.

In our case, we have images of pizza, steak and sushi in standard image classification format.

Image classification format contains separate classes of images in separate directories titled with a
particular class name.

For example, all images of pizza are contained in the pizza/ directory.

This format is popular across many different image classification benchmarks, including ImageNet (one of the
most popular computer vision benchmark datasets).

You can see an example of the storage format below (image file names/numbers are arbitrary).

pizza_steak_sushi/ <- overall dataset folder
    train/ <- training images
        pizza/ <- class name as folder name
        steak/
        sushi/
    test/ <- testing images
        pizza/
        steak/
        sushi/

The goal will be to take this data storage structure and turn it into a dataset usable with PyTorch.

Note: The structure of the data you work with will vary depending on the problem you're working on. But
the premise still remains: become one with the data, then find a way to best turn it into a dataset
compatible with PyTorch.

We can inspect what's in our data directory by writing a small helper function to walk through each of the
subdirectories and count the files present.

To do so, we'll use Python's in-built os.walk().

In [4]: import os

def walk_through_dir(dir_path):
    """
    Walks through dir_path returning its contents.

    Args:
        dir_path (str or pathlib.Path): target directory

    Returns:
        A print out of:
            number of subdirectories in dir_path
            number of images (files) in each subdirectory
            name of each subdirectory
    """
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [5]: walk_through_dir(image_path)

There are 2 directories and 1 images in 'data/pizza_steak_sushi'.

There are 3 directories and 0 images in 'data/pizza_steak_sushi/test'.
There are 0 directories and 19 images in 'data/pizza_steak_sushi/test/steak'.
There are 0 directories and 31 images in 'data/pizza_steak_sushi/test/sushi'.
There are 0 directories and 25 images in 'data/pizza_steak_sushi/test/pizza'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/train'.
There are 0 directories and 75 images in 'data/pizza_steak_sushi/train/steak'.
There are 0 directories and 72 images in 'data/pizza_steak_sushi/train/sushi'.
There are 0 directories and 78 images in 'data/pizza_steak_sushi/train/pizza'.


It looks like we've got about 75 images per training class and 25 images per testing class.

That should be enough to get started.
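If you'd rather have those counts as data than as print-outs (for example, to check class balance), a small variation on `walk_through_dir()` can collect them into a dictionary. The helper name `count_images_per_class()` and the toy folder layout built below are hypothetical; the layout only mimics the standard image classification format, and the "leaf folder = class folder" rule assumes there are no deeper subfolders.

```python
import os
import tempfile
from pathlib import Path


def count_images_per_class(split_dir):
    """Return {class_name: number_of_files} for each leaf folder under split_dir."""
    counts = {}
    for dirpath, dirnames, filenames in os.walk(split_dir):
        if not dirnames:  # in this layout, folders with no subfolders are class folders
            counts[Path(dirpath).name] = len(filenames)
    return counts


# Toy layout: <root>/train/<class_name>/<image files> (hypothetical names, empty files)
demo_root = Path(tempfile.mkdtemp())
for class_name, n_images in {"pizza": 3, "steak": 2, "sushi": 4}.items():
    class_dir = demo_root / "train" / class_name
    class_dir.mkdir(parents=True)
    for i in range(n_images):
        (class_dir / f"{i}.jpg").touch()

train_counts = count_images_per_class(demo_root / "train")
print(dict(sorted(train_counts.items())))  # {'pizza': 3, 'steak': 2, 'sushi': 4}
```

Pointed at `train_dir` in this notebook, the same helper should report the per-class numbers printed above.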

Remember, these images are subsets of the original Food101 dataset.

You can see how they were created in the data creation notebook.

While we're at it, let's set up our training and testing paths.

In [6]: # Setup train and testing paths

train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

Out[6]: (PosixPath('data/pizza_steak_sushi/train'),
 PosixPath('data/pizza_steak_sushi/test'))

2.1 Visualize an image

Okay, we've seen how our directory structure is formatted.

Now in the spirit of the data explorer, it's time to visualize, visualize, visualize!

Let's write some code to:

1. Get all of the image paths using pathlib.Path.glob() to find all of the files ending in .jpg.

2. Pick a random image path using Python's random.choice().

3. Get the image class name using pathlib.Path.parent.stem.

4. And since we're working with images, we'll open the random image path using PIL's Image.open() (PIL
stands for Python Imaging Library).

5. We'll then show the image and print some metadata.
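Steps 1 to 3 only need the standard library, so they can be tried on a throwaway folder before touching the real images. In this sketch the folder layout and file names are hypothetical, and no image is opened (step 4 needs PIL).

```python
import random
import tempfile
from pathlib import Path

# Hypothetical toy dataset: <root>/<split>/<class>/<file>.jpg
demo_root = Path(tempfile.mkdtemp())
for split in ("train", "test"):
    for class_name in ("pizza", "steak", "sushi"):
        class_dir = demo_root / split / class_name
        class_dir.mkdir(parents=True)
        (class_dir / f"{class_name}_{split}.jpg").touch()

# 1. Every .jpg exactly three levels below the root (split/class/file)
demo_paths = sorted(demo_root.glob("*/*/*.jpg"))

# 2. Pick one at random (seeded for repeatability)
random.seed(42)
demo_choice = random.choice(demo_paths)

# 3. The class name is the name of the folder that contains the file
demo_class = demo_choice.parent.stem

print(len(demo_paths))                               # 6
print(demo_class == demo_choice.stem.split("_")[0])  # True
```

Because each `*` in the glob pattern matches exactly one path component, `"*/*/*.jpg"` only finds files at the split/class/file depth, which is what keeps `parent.stem` equal to the class name.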

In [7]: import random
from PIL import Image

# Set seed
random.seed(42) # <- try changing this and see what happens

# 1. Get all image paths (* means "any combination")
image_path_list = list(image_path.glob("*/*/*.jpg"))

# 2. Get random image path
random_image_path = random.choice(image_path_list)

# 3. Get image class from path name (the image class is the name of the directory where the image is stored)
image_class = random_image_path.parent.stem

# 4. Open image
img = Image.open(random_image_path)

# 5. Print metadata
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")

Random image path: data/pizza_steak_sushi/test/pizza/2124579.jpg

Image class: pizza
Image height: 384
Image width: 512

We can do the same with matplotlib.pyplot.imshow(), except here we'll convert the image to a
NumPy array first.

In [8]: import numpy as np
import matplotlib.pyplot as plt

# Turn the image into an array
img_as_array = np.asarray(img)

# Plot the image with matplotlib
plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis("off");

3. Transforming data
Now what if we wanted to load our image data into PyTorch?

Before we can use our image data with PyTorch we need to:

1. Turn it into tensors (numerical representations of our images).

2. Turn it into a PyTorch Dataset and subsequently a PyTorch DataLoader.
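Conceptually, a DataLoader takes a Dataset and hands it back in (optionally shuffled) batches; `batch_size` and `shuffle` are real DataLoader parameters. The generator below is a plain-Python sketch of that idea only. It is not how `torch.utils.data.DataLoader` is implemented, `iterate_batches()` is a hypothetical name, and the real DataLoader also collates samples into tensors and can load data in parallel worker processes.

```python
import random


def iterate_batches(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of dataset items, batch_size at a time (the last batch may be smaller)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # local RNG, so global random state is untouched
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


toy_dataset = [(f"image_{i}.jpg", i % 3) for i in range(10)]  # (sample, label) pairs
batches = list(iterate_batches(toy_dataset, batch_size=4, shuffle=True, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

Only `len()` and indexing are used on `dataset`, which is why anything following the Dataset protocol can be batched this way.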

There are several different kinds of pre-built datasets and dataset loaders for PyTorch, depending on the
problem you're working on.

Problem space | Pre-built Datasets and Functions
--- | ---
Vision | torchvision.datasets
Audio | torchaudio.datasets
Text | torchtext.datasets
Recommendation system | torchrec.datasets

Since we're working with a vision problem, we'll be looking at torchvision.datasets for our data loading
functions as well as torchvision.transforms for preparing our data.

Let's import some base libraries.

In [9]: import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

3.1 Transforming data with torchvision.transforms

We've got folders of images, but before we can use them with PyTorch, we need to convert them into tensors.

One of the ways we can do this is by using the torchvision.transforms module.

torchvision.transforms contains many pre-built methods for formatting images, turning them into
tensors and even manipulating them for data augmentation purposes (data augmentation is the practice of
altering data to make it harder for a model to learn; we'll see this later on).

To get experience with torchvision.transforms, let's write a series of transform steps that:

1. Resize the images using transforms.Resize() (from about 512x512 to 64x64, the same shape as
the images on the CNN Explainer website).

2. Flip our images randomly on the horizontal using transforms.RandomHorizontalFlip() (this could
be considered a form of data augmentation because it will artificially change our image data).

3. Turn our images from a PIL image to a PyTorch tensor using transforms.ToTensor().

We can compile all of these steps using torchvision.transforms.Compose().
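`Compose()` chains transforms so the output of one becomes the input of the next, in list order. The class below is a minimal plain-Python sketch of that idea, not torchvision's source; `ComposeSketch` and the string "transforms" are hypothetical stand-ins for image operations.

```python
class ComposeSketch:
    """Apply a list of callables in order, feeding each output into the next."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


# Toy "transforms" on strings, standing in for image operations
pipeline = ComposeSketch([str.strip, str.lower, lambda s: s.replace(" ", "_")])
print(pipeline("  Pizza Steak Sushi "))  # pizza_steak_sushi
```

Order matters: in `data_transform`, `ToTensor()` comes last because the resize and flip steps are applied to the PIL image first.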

In [10]: # Write transform for image
data_transform = transforms.Compose([
    # Resize the images to 64x64
    transforms.Resize(size=(64, 64)),
    # Flip the images randomly on the horizontal
    transforms.RandomHorizontalFlip(p=0.5), # p = probability of flip, 0.5 = 50% chance
    # Turn the image into a torch.Tensor
    transforms.ToTensor() # this also converts all pixel values from 0 to 255 to be between 0.0 and 1.0
])
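For 8-bit images, `ToTensor()` does two things worth seeing in isolation: it rescales pixel values from [0, 255] to [0.0, 1.0] and it moves the colour channel first, from [H, W, C] to [C, H, W]. The sketch below mimics both steps in plain Python on a tiny made-up 2x2 RGB "image" stored as nested lists (not a real PIL image or tensor).

```python
# Hypothetical 2x2 RGB image stored as [height][width][channel] with 0-255 values
image_hwc = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [51, 102, 204]],
]

height, width, channels = len(image_hwc), len(image_hwc[0]), len(image_hwc[0][0])

# Rescale each value to [0.0, 1.0] and reorder to [channel][height][width]
image_chw = [
    [[image_hwc[h][w][c] / 255 for w in range(width)] for h in range(height)]
    for c in range(channels)
]

print(len(image_chw), len(image_chw[0]), len(image_chw[0][0]))  # 3 2 2
print(image_chw[2][1])  # [1.0, 0.8]
```

`image_chw[2]` is the blue channel, so its second row holds the blue values of the bottom two pixels (255 and 204, divided by 255).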

Now we've got a composition of transforms, let's write a function to try them out on various images.

In [11]: def plot_transformed_images(image_paths, transform, n=3, seed=42):
    """Plots a series of random images from image_paths.

    Will open n image paths from image_paths, transform them
    with transform and plot them side by side.

    Args:
        image_paths (list): List of target image paths.
        transform (PyTorch Transforms): Transforms to apply to images.
        n (int, optional): Number of images to plot. Defaults to 3.
        seed (int, optional): Random seed for the random generator. Defaults to 42.
    """
    random.seed(seed)
    random_image_paths = random.sample(image_paths, k=n)
    for image_path in random_image_paths:
        with Image.open(image_path) as f:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(f)
            ax[0].set_title(f"Original \nSize: {f.size}")
            ax[0].axis("off")

            # Transform and plot image
            # Note: permute() will change shape of image to suit matplotlib
            # (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
            transformed_image = transform(f).permute(1, 2, 0)
            ax[1].imshow(transformed_image)
            ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
            ax[1].axis("off")

            fig.suptitle(f"Class: {image_path.parent.stem}", fontsize=16)

plot_transformed_images(image_path_list,
                        transform=data_transform,
                        n=3)
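The `permute(1, 2, 0)` call follows a simple rule: output axis i is input axis `dims[i]`. The helper below is a plain-Python sketch that applies that rule to shapes only (the name `permuted_shape()` is hypothetical), which is handy for checking a permute before running it on a real tensor.

```python
def permuted_shape(shape, dims):
    """Shape after permute(*dims): output axis i is input axis dims[i]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a rearrangement of the axis indices")
    return tuple(shape[d] for d in dims)


chw_shape = (3, 64, 64)                           # PyTorch image layout [C, H, W]
hwc_shape = permuted_shape(chw_shape, (1, 2, 0))  # what matplotlib expects
print(hwc_shape)                                  # (64, 64, 3)
print(permuted_shape(hwc_shape, (2, 0, 1)))       # (3, 64, 64), back to [C, H, W]
```

So `(1, 2, 0)` reads as "height, then width, then channels", and `(2, 0, 1)` undoes it.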



We've now got a way to convert our images to tensors using torchvision.transforms.

We can also manipulate their size and orientation if needed (some models prefer images of different sizes and shapes).

Generally, the larger the image, the more information a model can recover.

For example, an image of size [256, 256, 3] will have 16x more pixels than an image of size [64, 64, 3]
((256*256*3)/(64*64*3) = 16).

However, the tradeoff is that more pixels require more computation.
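The 16x figure can be checked directly. The helper name `num_values()` is hypothetical; the sketch also shows that the ratio depends only on height and width, because the colour channels cancel out.

```python
def num_values(height, width, channels=3):
    """Number of values in an image of shape [height, width, channels]."""
    return height * width * channels


small = num_values(64, 64)    # 12,288 values
large = num_values(256, 256)  # 196,608 values
print(large / small)          # 16.0

# With a single channel the ratio is unchanged: channels cancel out
print(num_values(256, 256, 1) / num_values(64, 64, 1))  # 16.0
```

Going from 64 to 256 pixels per side multiplies each dimension by 4, so the total grows by 4 x 4 = 16.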

Exercise: Try commenting out one of the transforms in data_transform and running the plotting
function plot_transformed_images() again, what happens?

4. Option 1: Loading Image Data Using ImageFolder

Alright, time to turn our image data into a Dataset capable of being used with PyTorch.

Since our data is in standard image classification format, we can use the class
torchvision.datasets.ImageFolder, passing it the file path of a target image directory as well as a series of
transforms we'd like to perform on our images.

Let's test it out on our data folders train_dir and test_dir passing in transform=data_transform to
turn our images into tensors.
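Under the hood, `ImageFolder` treats each subfolder of `root` as a class: the sorted folder names become the dataset's `classes` attribute, and each name is mapped to an integer index in its `class_to_idx` attribute. The function below is a stdlib-only sketch of that convention (the name `find_classes_sketch()` and the toy folders are hypothetical), not torchvision's code.

```python
import os
import tempfile
from pathlib import Path


def find_classes_sketch(directory):
    """Return (sorted class names, {class_name: index}) based on subfolder names."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folders in {directory}.")
    class_to_idx = {class_name: i for i, class_name in enumerate(classes)}
    return classes, class_to_idx


# Toy root folder; folders are created in a deliberately non-alphabetical order
demo_train_dir = Path(tempfile.mkdtemp())
for class_name in ("sushi", "pizza", "steak"):
    (demo_train_dir / class_name).mkdir()

demo_classes, demo_class_to_idx = find_classes_sketch(demo_train_dir)
print(demo_classes)       # ['pizza', 'steak', 'sushi']
print(demo_class_to_idx)  # {'pizza': 0, 'steak': 1, 'sushi': 2}
```

Sorting is what makes the label indices stable: the same folders always get the same integers, regardless of the order they were created in.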

In [12]: # Use ImageFolder to create dataset(s)
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir, # target folder of images
                                  transform=data_transform, # transforms to perform on data (images)
                                  target_transform=None) # transforms to perform on labels

