  • Stars: 863
  • Rank: 52,844 (Top 2 %)
  • Language: Python
  • License: MIT License
  • Created about 7 years ago
  • Updated over 4 years ago

Repository Details

Some custom dataset examples for PyTorch


PyTorch Custom Dataset Examples

Update after two years: It has been a long time since I created this repository to guide people who are getting started with PyTorch (like myself back then). However, over the years and across various projects, the way I create my datasets has changed many times. I have included an additional bare-bones dataset here to show what I am currently using. I would like to note that custom datasets are called custom because you can shape them in any way you desire. So, it is only natural that you (the reader) will develop your own way of creating custom datasets after working on different projects. The examples presented in this project are not the ultimate way of creating datasets; they are there to show the flexibility and the possibilities of PyTorch datasets. I hope this repository is/was useful in your understanding of PyTorch datasets.

There are some official custom dataset examples in the PyTorch repo, but they still seemed a bit obscure to a beginner (like me, back then), so I had to spend some time understanding what exactly I needed to have a fully customized dataset. To save you the trouble of going through bajillions of pages, I decided to write down the basics of PyTorch datasets here. The topics are as follows.

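Topics

  • Custom Dataset Fundamentals
  • Using Torchvision Transforms
  • Another Way to Use Torchvision Transforms
  • Incorporating Pandas
  • Incorporating Pandas with More Logic
  • A Custom-custom-custom Dataset
  • Using Data Loader
  • Future Updates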

Custom Dataset Fundamentals

The first and foremost part is creating a dataset class.

from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)

    def __len__(self):
        return count # of how many examples(images?) you have

This is the skeleton that you have to fill in to have a custom dataset. A dataset must contain the following functions so that it can be used by a data loader later on.

  • __init__() function is where the initial logic happens like reading a csv, assigning transforms, filtering data, etc.
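  • __getitem__() function returns the data and labels for a given index. The data loader calls it on an instance of MyCustomDataset (called my_custom_dataset here), which is equivalent to:
img, label = my_custom_dataset[99]  # Same as my_custom_dataset.__getitem__(99), the item at index 99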

Here, MyCustomDataset returns two things, an image and a label, but that does not mean that __getitem__() is restricted to returning only those. Depending on your application, you can return many things.

An important thing to note is that __getitem__() must return each data point in a type that the data loader can batch (like a tensor, a numpy array, a number, etc.); otherwise, the data loader will raise an error like:

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>

  • __len__() returns the number of samples you have.
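
To make the three functions concrete, here is a minimal sketch that runs as written. It is only an illustration: RandomImageDataset and BadDataset are hypothetical names, and the data is random tensors instead of real images. The second class returns a plain Python object on purpose, to trigger the kind of TypeError mentioned above when the data loader tries to build a batch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomImageDataset(Dataset):
    """Hypothetical dataset of random 1x28x28 'images' with integer labels."""

    def __init__(self, num_samples=100):
        # Initial logic: here the data is simply generated up front
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 1, 28, 28, generator=generator)
        self.labels = torch.randint(0, 10, (num_samples,), generator=generator)

    def __getitem__(self, index):
        # Return a single sample as (tensor, int)
        return self.images[index], int(self.labels[index])

    def __len__(self):
        return len(self.images)


class BadDataset(Dataset):
    """Returns a plain Python object, which the default batching cannot handle."""

    def __getitem__(self, index):
        return object()

    def __len__(self):
        return 4


dataset = RandomImageDataset()
img, label = dataset[99]  # same as dataset.__getitem__(99)
print(len(dataset), tuple(img.shape), label)

try:
    next(iter(DataLoader(BadDataset(), batch_size=2)))
    collate_failed = False
except TypeError as error:
    collate_failed = True
    print("TypeError:", error)
```

The exact wording of the error message depends on the PyTorch version, so the sketch only checks that a TypeError is raised.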

Using Torchvision Transforms

In most of the examples you will see transforms=None in __init__(); this argument is used to apply torchvision transforms to your data/image. You can find the extensive list of the transforms in the torchvision documentation. The most common usage of transforms is like this:

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        # If transforms were passed in, apply them in the order
        # in which they were composed
        if self.transforms is not None:
            data = self.transforms(data)
        return (data, label)

    def __len__(self):
        return count # of how many data(images?) you have
        
if __name__ == '__main__':
    # Define transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # Call the dataset
    custom_dataset = MyCustomDataset(..., transformations)
    

Another Way to Use Torchvision Transforms

Personally, I don't like having dataset transforms outside the dataset class (see (1) above). So, instead of using transforms like in the example above, you can also use them like this:

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) One way to do it is define transforms individually
        # When you define the transforms it calls __init__() of the transform
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # (3) Or you can still compose them like 
        self.transformations = \
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        
        # Calling a transform object invokes its __call__() and applies the transform
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)  # (2)
        
        # Or you can call the composed version
        data = self.transformations(data)  # (3)
        
        # Note that you only need one of the implementations, (2) or (3)
        return (data, label)

    def __len__(self):
        return count # of how many data(images?) you have
        
if __name__ == '__main__':
    # Call the dataset
    custom_dataset = MyCustomDataset(...)
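
The difference between creating a transform (which runs __init__()) and applying it (which runs __call__()) can be seen with a small hand-written transform. This is a sketch under assumptions: Scale and ScaledDataset are hypothetical names and not part of torchvision, and the call counter exists only to make the behaviour visible.

```python
import torch
from torch.utils.data import Dataset


class Scale:
    """Hypothetical transform: multiplies a tensor by a fixed factor."""

    def __init__(self, factor):
        # Runs once, when the transform object is created
        self.factor = factor
        self.num_calls = 0

    def __call__(self, data):
        # Runs every time the object is called like a function
        self.num_calls += 1
        return data * self.factor


class ScaledDataset(Dataset):
    def __init__(self):
        self.data = torch.ones(3, 2)
        self.scale = Scale(0.5)  # calls Scale.__init__()

    def __getitem__(self, index):
        return self.scale(self.data[index])  # calls Scale.__call__()

    def __len__(self):
        return len(self.data)


dataset = ScaledDataset()
samples = [dataset[i] for i in range(len(dataset))]
print(dataset.scale.num_calls, samples[0].tolist())
```

Defining the transform inside the dataset keeps everything that touches a sample in one place, which is the point this section argues for.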
    

Incorporating Pandas

Let's say we want to read some data from a csv with pandas. The first example uses a csv file like the following (without the headers, even though it really doesn't matter) that contains a file name, a label (class), and an extra operation indicator. Depending on this extra operation flag, we do some operation on the image.

File Name    Label    Extra Operation
tr_0.png     5        TRUE
tr_1.png     0        FALSE
tr_1.png     4        FALSE

If we want to build a custom dataset that reads image locations from this csv file, then we can do something like the following.

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)

        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label(class) of the image from the label column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    custom_mnist_from_images =  \
        CustomDatasetFromImages('../data/mnist_labels.csv')
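
To see what the three arrays built in __init__() contain, the sketch below parses the same three rows from an in-memory string instead of a file. The comma-separated layout of csv_text is an assumption about how the csv is stored on disk.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical csv contents matching the three-column layout above (no header row)
csv_text = "tr_0.png,5,TRUE\ntr_1.png,0,FALSE\ntr_1.png,4,FALSE\n"

data_info = pd.read_csv(io.StringIO(csv_text), header=None)
image_arr = np.asarray(data_info.iloc[:, 0])      # file names
label_arr = np.asarray(data_info.iloc[:, 1])      # integer labels
operation_arr = np.asarray(data_info.iloc[:, 2])  # TRUE/FALSE parsed as booleans

for name, label, flag in zip(image_arr, label_arr, operation_arr):
    print(name, label, "extra operation" if flag else "no extra operation")
```

Because pandas parses TRUE and FALSE as booleans, the flag can be used directly in an if statement, as __getitem__() does above.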

Incorporating Pandas with More Logic

Yet another example might be reading an image from a CSV where the value of each pixel is listed in a column (sometimes MNIST is given this way). This just changes the logic in __getitem__(). In the end, you just return images as tensors and their labels. The data is divided into pixels like this:

Label    pixel_1    pixel_2    ...
1        50         99         ...
0        21         223        ...
9        44         112        ...

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): path to csv file
            height (int): image height
            width (int): image width
            transforms: pytorch transforms for transforms and tensor conversion
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transforms

    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # Read the pixel values of this row and reshape the 1D array
        # ([height * width]) to a 2D array ([height, width])
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(self.height, self.width).astype('uint8')
        # Convert image from numpy array to PIL image, mode 'L' is for grayscale
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # Transform image to tensor; without transforms the PIL image is
        # returned as it is, which a data loader cannot batch
        img_as_tensor = img_as_img
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # Return image and the label
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return len(self.data.index)
        

if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
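
The same idea can be tried without any files. The sketch below is a simplified variant, not the class above: it skips the PIL conversion and builds the tensor with torch.from_numpy, and it reads two tiny 2x2 "images" from an in-memory string. TinyPixelCSVDataset and the csv contents are made up for illustration.

```python
import io

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class TinyPixelCSVDataset(Dataset):
    def __init__(self, csv_file, height, width):
        self.data = pd.read_csv(csv_file)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width

    def __getitem__(self, index):
        # All columns after the first one hold the pixel values of one image
        pixels = np.asarray(self.data.iloc[index, 1:], dtype=np.uint8)
        img = pixels.reshape(self.height, self.width)
        # Scale to [0, 1] and add a channel dimension: 1 x H x W
        img_as_tensor = torch.from_numpy(img).float().div(255).unsqueeze(0)
        return img_as_tensor, int(self.labels[index])

    def __len__(self):
        return len(self.data.index)


# Hypothetical csv with a header row, one label column and four pixel columns
csv_text = "label,pixel_1,pixel_2,pixel_3,pixel_4\n1,0,255,51,102\n0,255,255,0,0\n"
dataset = TinyPixelCSVDataset(io.StringIO(csv_text), height=2, width=2)
img, label = dataset[0]
print(tuple(img.shape), label)
```

Taking height and width from the constructor, instead of hard-coding 28, is what lets the same class work for images of any size.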
        

A Custom-custom-custom Dataset

It has been a long time since I updated this repository (huh, 2 years), and during that time I have completely stopped using torchvision transforms, and also csvs and whatnot (unless absolutely necessary). One reason I stopped using torchvision transforms is that I have coded my own transforms. More importantly, I disliked the way most of the best-practice examples pass transforms as an argument when the dataset class is initialized, which I do not think is the best way of doing things. More often than not, for the training datasets I have coded over time, I had to apply some form of preprocessing operation (flip, mirror, pad, add noise, saturate, crop, ...) randomly, and I wanted the freedom to choose the degree to which I apply them, or whether to apply them at all. So, my datasets often have a flow like below:

1- Get the location of the image/data to read 
2- Read the image/data
3- Convert to numpy
4- Do some processing on numpy array (randomly)
5- Do some processing on numpy array (randomly)
6- Do some processing on numpy array (randomly)
...
15- Convert data to tensor
return location of data, name of data, data, and label

You can obviously apply transforms just like I listed above too; in the end, it is a matter of taste.

Below, I'm sharing a bare-bones custom dataset that I'm using for most of my experiments.
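
The numpy-based flow listed above can be sketched roughly as follows. The two operations (a random horizontal flip and additive Gaussian noise), the function name random_preprocess, and the parameters flip_prob and noise_std are illustrative choices, not the exact operations used in any particular project.

```python
import numpy as np
import torch


def random_preprocess(image, rng, flip_prob=0.5, noise_std=0.05):
    """Hypothetical pipeline: random horizontal flip, then additive Gaussian noise."""
    image = image.astype(np.float32)
    # Randomly mirror the image left-right
    if rng.random() < flip_prob:
        image = image[:, ::-1]
    # Add noise whose strength is controlled by noise_std
    if noise_std > 0:
        image = image + rng.normal(0.0, noise_std, size=image.shape)
    image = np.clip(image, 0.0, 1.0)
    # Make the array contiguous float32 before handing it to torch
    return torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))


rng = np.random.default_rng(0)
image = rng.random((28, 28))  # stand-in for an image already scaled to [0, 1]
tensor = random_preprocess(image, rng)
print(tuple(tensor.shape), tensor.dtype)
```

Keeping the probability and the strength as plain arguments is one way to get the freedom of choosing how much of each operation is applied, or whether it is applied at all.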

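import glob

import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset


class CustomDatasetFromFile(Dataset):
    def __init__(self, folder_path):
        """
        A dataset example where the class is embedded in the file names
        This data example also does not use any torch transforms

        Args:
            folder_path (string): path to image folder, ending with a path separator
        """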
        # Get image list
        self.image_list = glob.glob(folder_path+'*')
        # Calculate len
        self.data_len = len(self.image_list)

    def __getitem__(self, index):
        # Get image path from the image list
        single_image_path = self.image_list[index]
        # Open image
        im_as_im = Image.open(single_image_path)

        # Do some operations on image
        # Convert to numpy, dim = 28x28
        im_as_np = np.asarray(im_as_im)/255
        # Add channel dimension, dim = 1x28x28
        # Note: You do not need to do this if you are reading RGB images
        # or if there is already a channel dimension
        im_as_np = np.expand_dims(im_as_np, 0)
        # Some preprocessing operations on numpy array
        # ...
        # ...
        # ...

        # Transform image to tensor, change data type
        im_as_ten = torch.from_numpy(im_as_np).float()

        # Get label(class) of the image from the file name:
        # the single digit that follows the last '_c' in the path
        class_indicator_location = single_image_path.rfind('_c')
        label = int(single_image_path[class_indicator_location+2:class_indicator_location+3])
        return (im_as_ten, label)

    def __len__(self):
        return self.data_len
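
The label extraction above reads exactly one character after the last '_c', so it only works for single-digit classes. If labels can have more digits, a regular expression is one possible alternative. In the sketch below, label_from_path and the example file names are hypothetical.

```python
import os
import re


def label_from_path(path):
    """Return the integer that follows the last '_c' marker in the file name."""
    file_name = os.path.basename(path)
    matches = re.findall(r"_c(\d+)", file_name)
    if not matches:
        raise ValueError(f"no '_c<digits>' class marker in {file_name!r}")
    return int(matches[-1])


print(label_from_path("../data/img_0001_c7.png"))
print(label_from_path("../data/img_0002_c12.png"))
```

Looking only at the file name, and not at the whole path, also avoids picking up a '_c' that happens to appear in a folder name.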

Using Data Loader

PyTorch DataLoaders just call __getitem__() and wrap the results up into a batch. Technically, we could skip data loaders, call __getitem__() one item at a time, and feed the data to the models (even though a data loader is super convenient). Continuing from the example above, if we assume there is a custom dataset called CustomDatasetFromCSV, then we can call the data loader like this:

...
if __name__ == "__main__":
    # Define transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # Define custom dataset
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv',
                             28, 28,
                             transformations)
    # Define data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # Feed the data to the model
        pass

The first argument of the data loader is the dataset; from there it calls __getitem__() of that dataset. batch_size determines how many individual data points will be wrapped into a single batch. If we assume a single image tensor is of size 1x28x28 (D:1, H:28, W:28), then with this data loader the returned tensor will be 10x1x28x28 (Batch-Depth-Height-Width).
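
The batch shapes can be checked without any image files. The sketch below uses a TensorDataset of 25 zero-filled 1x28x28 tensors as a stand-in for the MNIST dataset (the number 25 is arbitrary), and it also shows the manual route without a data loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for an image dataset: 25 samples of size 1x28x28
images = torch.zeros(25, 1, 28, 28)
labels = torch.arange(25)
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset=dataset, batch_size=10, shuffle=False)
batch_shapes = [tuple(batch_images.shape) for batch_images, _ in loader]
print(batch_shapes)

# Without a data loader: index one sample and add the batch dimension by hand
single_image, single_label = dataset[0]
manual_batch = single_image.unsqueeze(0)
print(tuple(manual_batch.shape))
```

Since 25 is not a multiple of 10, the last batch holds only 5 samples; DataLoader has a drop_last argument for cases where a smaller final batch is a problem.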

A note on using multiple GPUs. When a model is wrapped for multi-GPU use (for example with torch.nn.DataParallel), each batch coming from the data loader is split as evenly as possible among the GPUs you have. So, if you use a batch size that is smaller than the number of GPUs, it won't be able to utilize all of them.
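
The effect of the batch size can be illustrated on the CPU with torch.chunk, which splits a tensor along a dimension. This is only an analogy for how a batch gets divided along the batch dimension, not the multi-GPU code path itself, and per_device_sizes is a hypothetical helper.

```python
import torch


def per_device_sizes(batch_size, num_devices):
    """Split a dummy batch along dim 0 and report the size of each piece."""
    batch = torch.zeros(batch_size, 1, 28, 28)
    return [chunk.shape[0] for chunk in torch.chunk(batch, num_devices, dim=0)]


print(per_device_sizes(10, 4))  # four pieces, the last one smaller
print(per_device_sizes(2, 4))   # only two pieces, so two devices would stay idle
```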

Future Updates

I will continue updating this repository whenever I find spare time. Below are some of the things I plan to include. Please let me know if you would like to see some other specific examples.

  • A working custom dataset for Imagenet with normalizations etc. Note: Huh, I have been wanting to do this for a long time. I will probably do this and maybe have another dataset for CIFAR too.
  • A custom dataset example for encoder-decoder networks like U-Net. Note: There is already a U-Net dataset that my students coded for their project; have a look at it if you are interested.
