PyTorch Dataset class

The PyTorch Dataset class is the most basic class for representing a dataset. In this chapter of the PyTorch Tutorial, you will learn about the PyTorch Dataset class. You will also get a brief overview of the other classes PyTorch provides for handling different types of datasets.

Note: Throughout the rest of this chapter, "dataset" refers to the generic notion of a dataset, that is, a collection of data, whereas the PyTorch Dataset class is written as Dataset.

The Dataset class

The most basic class for handling a dataset in PyTorch is the Dataset class, and the other dataset classes in torch.utils.data inherit from it. PyTorch provides many such classes for handling various types of datasets, and you can also create your own dataset class by inheriting from Dataset.

Importing the Dataset class

You can import the Dataset class from the torch.utils.data module:

from torch.utils.data import Dataset

Creating your own Dataset class

To create your own class for handling a dataset, you must inherit from the Dataset class and implement the __getitem__() method. If __getitem__() is not implemented, the base class raises a NotImplementedError when an item is accessed. Implementing the __len__() method is optional, but it is advisable: the base Dataset class does not provide one, and it is used by samplers and data loaders such as DataLoader (these are used along with dataset classes to handle data and feed it to the neural network).
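A quick way to see both behaviours described in the paragraph above is a subclass that implements neither method. The sketch below assumes a recent PyTorch release, and the class name Incomplete is hypothetical. Indexing it falls through to the base-class __getitem__(), which raises NotImplementedError, while len() fails with a TypeError because Dataset defines no __len__().

```python
from torch.utils.data import Dataset


class Incomplete(Dataset):
    """Hypothetical subclass that implements neither __getitem__ nor __len__."""


ds = Incomplete()

getitem_error = None
try:
    ds[0]  # falls through to Dataset.__getitem__, which raises
except NotImplementedError as exc:
    getitem_error = type(exc).__name__

len_error = None
try:
    len(ds)  # Dataset provides no default __len__
except TypeError as exc:
    len_error = type(exc).__name__

print(getitem_error, len_error)  # NotImplementedError TypeError
```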

The __getitem__(idx) method should return the item at index idx. For supervised learning, this is commonly a tuple of the form (input, label), where input holds the input features (an n-dimensional tensor) and label is the target variable.

The __len__() method should return the number of items in the dataset.

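from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, inputs, labels):
        # illustrative storage: any indexable inputs and labels of equal length
        self.inputs = inputs
        self.labels = labels

    def __getitem__(self, idx):
        # compulsory: return the item at index idx, here as an (input, label) tuple;
        # if this method is missing, Dataset raises NotImplementedError
        return self.inputs[idx], self.labels[idx]

    def __len__(self):
        # optional but recommended: return the length/size of the dataset
        return len(self.inputs)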
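The following sketch shows a custom map-based dataset in use with DataLoader, which relies on both methods: its default sampler calls len() on the dataset to generate indices, and each index is passed to __getitem__(). SquaresDataset is a hypothetical example that computes its items on demand instead of storing them.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-based dataset: item idx is (x, x**2) with x = float(idx)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        # Raising IndexError for out-of-range indices also lets plain
        # `for item in dataset` loops stop at the end of the dataset.
        if not 0 <= idx < self.n:
            raise IndexError(f"index {idx} out of range for {self.n} items")
        x = torch.tensor([float(idx)])
        return x, x ** 2  # (input, label) tuple

    def __len__(self):
        return self.n


ds = SquaresDataset(10)
loader = DataLoader(ds, batch_size=4, shuffle=False)

# The default collate function turns each batch of (input, label) tuples
# into a list [inputs, labels], each stacked along a new first dimension.
batches = list(loader)
print(len(loader))             # 3 batches: sizes 4, 4 and 2
print(batches[1][1].view(-1))  # labels for indices 4..7
```

Because items are computed inside __getitem__(), the dataset holds nothing in memory beyond n; the same pattern applies when each item is loaded lazily, for example from a file.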

Note: Inherit the Dataset class directly only if your dataset class is meant to be map-based, that is, if it maps each index to an item.


Other dataset classes

PyTorch offers several other classes for handling datasets of various types. These include:

IterableDataset: A subclass of Dataset for handling data that comes from a stream. Unlike Dataset, it is not map-based, so instead of implementing the __getitem__() method, you implement the __iter__() method, which should return an iterator over the dataset. You can inherit the IterableDataset class to create your own dataset class that iterates over the data.

ChainDataset: A subclass of IterableDataset for efficiently chaining multiple IterableDatasets; the chaining is done on the fly as items are requested. It can be useful for combining existing IterableDatasets that come from different streams of data.

BufferedShuffleDataset: A subclass of IterableDataset for shuffling the items of an existing IterableDataset using a fixed-size buffer. Note that this class was only available in some older PyTorch releases and is not part of recent versions.

TensorDataset: A subclass of Dataset that wraps one or more tensors with the same size in the first dimension; item i is the tuple formed by indexing each tensor at i.

ConcatDataset: A subclass of Dataset for concatenating multiple existing map-based datasets into one.

Subset: A subclass of Dataset representing a subset of an original dataset, made up of the items at the given indices in the original dataset.
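The IterableDataset and ChainDataset entries can be sketched as follows. CountingStream is a hypothetical stand-in for a real data stream; its __iter__() is a generator function, which satisfies the requirement of returning an iterator. This assumes a PyTorch release that includes ChainDataset.

```python
import torch
from torch.utils.data import ChainDataset, DataLoader, IterableDataset


class CountingStream(IterableDataset):
    """Hypothetical stream yielding start, start + 1, ..., stop - 1 as tensors."""

    def __init__(self, start, stop):
        super().__init__()
        self.start = start
        self.stop = stop

    def __iter__(self):
        # A generator function returns an iterator, as IterableDataset requires.
        for i in range(self.start, self.stop):
            yield torch.tensor(i)


stream_a = CountingStream(0, 3)
stream_b = CountingStream(10, 12)

# ChainDataset exhausts stream_a, then continues with stream_b.
chained = ChainDataset([stream_a, stream_b])
values = [int(t) for t in chained]

# DataLoader batches an IterableDataset in the order its items are yielded
# (there are no indices to shuffle, so shuffle must stay False).
loader = DataLoader(chained, batch_size=2)
batches = [batch.tolist() for batch in loader]

print(values)   # [0, 1, 2, 10, 11]
print(batches)  # [[0, 1], [2, 10], [11]]
```

With num_workers > 0, DataLoader gives each worker process its own copy of an IterableDataset, so a stream like this one would yield duplicated items unless __iter__() splits the work, for example with torch.utils.data.get_worker_info().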
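The map-based helpers TensorDataset, ConcatDataset and Subset can be combined as in the following sketch; the tensors are small made-up values chosen only to make the indexing easy to follow.

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

# Small made-up data: 3 and 2 samples with 2 features each.
features_a = torch.arange(6, dtype=torch.float32).reshape(3, 2)      # [[0,1],[2,3],[4,5]]
labels_a = torch.tensor([0, 1, 0])
features_b = torch.arange(6, 10, dtype=torch.float32).reshape(2, 2)  # [[6,7],[8,9]]
labels_b = torch.tensor([1, 1])

# TensorDataset: item i is (features[i], labels[i]); all tensors must share dim 0.
ds_a = TensorDataset(features_a, labels_a)
ds_b = TensorDataset(features_b, labels_b)

# ConcatDataset: indices 0-2 map to ds_a, indices 3-4 map to ds_b.
combined = ConcatDataset([ds_a, ds_b])

# Dataset defines __add__, so ds_a + ds_b also builds a ConcatDataset.
also_combined = ds_a + ds_b

# Subset: only the items of `combined` at the listed indices.
picked = Subset(combined, [0, 3])

x, y = combined[3]
print(len(combined), x.tolist(), int(y))   # 5 [6.0, 7.0] 1
print(len(picked), picked[1][0].tolist())  # 2 [6.0, 7.0]
```

Subset does not copy any data; it keeps a reference to the original dataset together with the list of indices.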