The PyTorch Dataset class is the most basic class for representing a dataset. In this chapter of the PyTorch Tutorial, you will learn about the Dataset class, along with a brief overview of the other classes PyTorch provides for handling different types of datasets.
Note: Throughout the rest of this chapter, "dataset" refers to the generic definition of a dataset, a collection of data, whereas the PyTorch Dataset class is written as `Dataset`.
The Dataset class
The most basic class for handling a dataset in PyTorch is the `Dataset` class, and all the other dataset classes in PyTorch inherit from it, directly or indirectly. PyTorch provides many of these classes for handling various types of datasets, and you can also create your own dataset class by inheriting from `Dataset`.
Importing Dataset class
You can import the `Dataset` class from the `torch.utils.data` module:

```python
from torch.utils.data import Dataset
```
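You can check the claim that the other dataset classes inherit from `Dataset` directly in Python. This minimal sketch uses `issubclass`, which also follows indirect inheritance. The list name `OTHER_CLASSES` is an illustrative choice. `BufferedShuffleDataset` is left out because it is not available in every PyTorch release. Each printed line should end in `True`.

```python
from torch.utils.data import (
    ChainDataset,
    ConcatDataset,
    Dataset,
    IterableDataset,
    Subset,
    TensorDataset,
)

# Illustrative list of the classes described later in this chapter
OTHER_CLASSES = [IterableDataset, ChainDataset, TensorDataset, ConcatDataset, Subset]

for cls in OTHER_CLASSES:
    # issubclass follows indirect inheritance too (ChainDataset -> IterableDataset -> Dataset)
    print(cls.__name__, issubclass(cls, Dataset))
```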
Creating your own Dataset class
To create your own class for handling a dataset, inherit from the `Dataset` class and implement the `__getitem__()` method. If `__getitem__()` is not implemented, a `NotImplementedError` is raised when an item of the dataset is accessed; creating the object itself still succeeds. Implementing the `__len__()` method is optional, but it is advisable, because many samplers and the default options of `DataLoader` rely on it. Samplers and data loaders are used along with dataset classes to fetch data and feed it to the neural network.
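The following minimal sketch shows what happens when a subclass implements neither method. `IncompleteDataset` and `probe` are hypothetical names used only for illustration. Indexing falls through to the base class's `__getitem__()`, which raises `NotImplementedError`. `Dataset` provides no default `__len__()`, so `len()` fails with Python's usual `TypeError`.

```python
from torch.utils.data import Dataset


class IncompleteDataset(Dataset):
    """Hypothetical subclass that deliberately implements neither method."""
    pass


def probe(dataset):
    """Hypothetical helper: return the names of the exceptions raised by indexing and len()."""
    errors = []
    try:
        dataset[0]  # falls through to Dataset.__getitem__
    except NotImplementedError:
        errors.append("NotImplementedError")
    try:
        len(dataset)  # Dataset defines no default __len__
    except TypeError:
        errors.append("TypeError")
    return errors


ds = IncompleteDataset()  # instantiation itself succeeds
print(probe(ds))  # ['NotImplementedError', 'TypeError']
```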
The `__getitem__(idx)` method should return the item at index `idx`. For supervised learning, this item is typically a tuple of the form `(input, label)`, where `input` holds the input features (an n-dimensional tensor) and `label` is the target variable.
The `__len__()` method should return the length (size) of the dataset.
```python
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __getitem__(self, idx):
        # your logic to get the item at index idx
        # this method must be implemented, else NotImplementedError is raised on access
        pass

    def __len__(self):
        # your logic to return the length/size of the dataset
        # this method is optional, but samplers and DataLoader rely on it
        pass
```
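To make the template concrete, here is a minimal sketch of a complete map-based dataset. `SquaresDataset` is a hypothetical example in which the input is x and the label is x squared. It returns `(input, label)` tuples and reports its size through `__len__()`. `DataLoader` then batches the samples; with `shuffle=False` the batches follow index order.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-based dataset: input x, label x**2, for x in 0..n-1."""

    def __init__(self, n):
        # Features as an (n, 1) float tensor, labels as an (n,) tensor
        self.inputs = torch.arange(n, dtype=torch.float32).unsqueeze(1)
        self.labels = self.inputs.squeeze(1) ** 2

    def __getitem__(self, idx):
        # Return one (input, label) pair for index idx
        return self.inputs[idx], self.labels[idx]

    def __len__(self):
        # Number of samples; used by samplers and DataLoader
        return len(self.inputs)


dataset = SquaresDataset(10)
x, y = dataset[3]
print(x, y)  # tensor([3.]) tensor(9.)

# Default collation stacks samples into batches of shape (4, 1) and (4,)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for batch_inputs, batch_labels in loader:
    print(batch_inputs.shape, batch_labels.shape)
```

With 10 samples and `batch_size=4`, the loader yields two full batches followed by a final batch of 2.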
Note: Inherit the `Dataset` class directly only if your dataset class is meant to be map-based, that is, if it maps each index to an item.
Other dataset classes
PyTorch offers several other classes for handling datasets of various types. These include:
Dataset class | Purpose |
---|---|
IterableDataset | `IterableDataset` is a subclass of `Dataset` used for handling data that comes from a stream. Unlike `Dataset`, it is not map-based, so instead of implementing the `__getitem__()` method, you implement the `__iter__()` method, which should return an iterator over the dataset. You can inherit the `IterableDataset` class to create your own dataset class that iterates over the data. |
ChainDataset | `ChainDataset` is a subclass of `IterableDataset` used for efficiently chaining multiple `IterableDataset`s. The chaining happens on the fly, which makes this class useful for combining existing `IterableDataset`s from different data streams. |
BufferedShuffleDataset | `BufferedShuffleDataset` is a subclass of `IterableDataset` used for shuffling the items of an existing `IterableDataset` through a buffer of items. |
TensorDataset | `TensorDataset` is a subclass of `Dataset` that wraps one or more tensors sharing the same first dimension; each sample is retrieved by indexing every tensor along that dimension. |
ConcatDataset | `ConcatDataset` is a subclass of `Dataset` used for concatenating multiple existing map-based datasets into a single dataset. |
Subset | `Subset` is a subclass of `Dataset` used to create a dataset that contains only the items at the given indices of an original dataset. |
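The classes in the table can be tried out together in one short sketch. `RangeStream` is a hypothetical stream-style dataset, and the tensors are made-up toy data. `BufferedShuffleDataset` is left out because it is not available in every PyTorch release. The stream-style classes are iterated rather than indexed. The map-based ones support indexing and `len()`.

```python
import torch
from torch.utils.data import (
    ChainDataset,
    ConcatDataset,
    IterableDataset,
    Subset,
    TensorDataset,
)


class RangeStream(IterableDataset):
    """Hypothetical stream-style dataset yielding integers in [start, end)."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        # Return an iterator over the stream; no __getitem__ needed
        return iter(range(self.start, self.end))


# IterableDataset and ChainDataset: iterate, don't index
stream = list(ChainDataset([RangeStream(0, 3), RangeStream(10, 12)]))
print(stream)  # [0, 1, 2, 10, 11]

# TensorDataset: sample i is (features[i], labels[i])
features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
labels = torch.tensor([0, 1, 0, 1, 0, 1])
tensor_ds = TensorDataset(features, labels)
print(tensor_ds[2])  # (tensor([4., 5.]), tensor(0))

# ConcatDataset: indices continue across the concatenated datasets
concat_ds = ConcatDataset([tensor_ds, tensor_ds])
print(len(concat_ds))  # 12

# Subset: only the items at the given indices of the original dataset
subset = Subset(tensor_ds, [0, 2, 4])
print(len(subset), subset[1][1])  # 3 tensor(0)
```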