PyTorch Datasets
In deep learning tasks, data loading and preprocessing are a crucial part of the workflow.
PyTorch provides powerful data loading and processing tools, mainly including:
* **`torch.utils.data.Dataset`**: An abstract class for datasets. A custom dataset subclasses it and implements `__len__` (the dataset size) and `__getitem__` (retrieves a sample by index).
* **`torch.utils.data.TensorDataset`**: A ready-made dataset that wraps one or more tensors (for example, data and labels) sharing the same first dimension. Indexing it returns a tuple with one entry from each tensor, and it can be passed directly to a `DataLoader` for batching and iteration.
* **`torch.utils.data.DataLoader`**: An iterable that wraps a Dataset and provides batching, data shuffling, multi-process loading, and other features, making it easy to feed data into model training.
* **`torchvision.datasets.ImageFolder`**: Loads image data from a directory in which each subfolder represents one category; suitable for image classification tasks.
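As a minimal sketch of `TensorDataset` (the toy tensors below are assumptions for illustration), indexing returns one `(features, label)` tuple, while batching comes from wrapping the dataset in a `DataLoader`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (assumed for illustration): 8 samples with 3 features each, alternating labels
features = torch.arange(24, dtype=torch.float32).reshape(8, 3)
labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])

# TensorDataset pairs the tensors along their first dimension,
# so every tensor must have the same size in dimension 0
dataset = TensorDataset(features, labels)

x0, y0 = dataset[0]          # indexing returns one (feature, label) tuple
print(len(dataset), x0, y0)  # 8 tensor([0., 1., 2.]) tensor(0)

# Batching and shuffling are handled by DataLoader, not by TensorDataset itself
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batches = list(loader)
print(len(batches), batches[0][0].shape)  # 2 torch.Size([4, 3])
```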
### PyTorch Built-in Datasets
The PyTorch ecosystem provides many commonly used datasets, mainly through the `torchvision.datasets` module, for example:
* **MNIST**: Handwritten digit image dataset, used for image classification tasks.
* **CIFAR**: CIFAR-10 contains 60,000 32x32 color images in 10 categories (CIFAR-100 has the same number of images in 100 categories), used for image classification tasks.
* **COCO** (Common Objects in Context): A large-scale dataset for object detection, segmentation, and keypoint detection, containing over 330k images and 2.5M object instances.
* **ImageNet**: Contains over 14 million images, used for image classification and object detection tasks.
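* **STL-10**: 96x96 color images in 10 classes (5,000 labeled training images, 8,000 test images, and 100,000 unlabeled images), used for image classification and unsupervised learning tasks.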
* **Cityscapes**: Contains 5,000 finely annotated urban street scene images, used for semantic segmentation tasks.
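* **SQuAD**: The Stanford Question Answering Dataset, used for machine reading comprehension tasks. It is a text dataset, so it is provided by torchtext rather than torchvision.
The image datasets above can be loaded with the corresponding classes in the `torchvision.datasets` module (for example, `torchvision.datasets.CIFAR10`); datasets that are not built in can be loaded by writing a custom `Dataset`.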
### torchvision and torchtext
* **torchvision**: A computer vision library that provides APIs and dataset interfaces for image data processing, including dataset loading functions and common image transformations.
* **torchtext**: A natural language processing toolkit that provides tools for text data processing and modeling, including data preprocessing and data loading methods.
* * *
## torch.utils.data.Dataset
Dataset is an abstract class for dataset representation in PyTorch.
Custom datasets need to inherit from torch.utils.data.Dataset and override the following two methods:
* `__len__`: Returns the size of the dataset.
* `__getitem__`: Retrieves a data sample and its label by index.
## Example
```python
import torch
from torch.utils.data import Dataset

# Custom dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # Store the data and labels
        self.data = data
        self.labels = labels

    def __len__(self):
        # Return the dataset size
        return len(self.data)

    def __getitem__(self, idx):
        # Return the sample and label at position idx
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Generate sample data
data = torch.randn(100, 5)            # 100 samples, each with 5 features
labels = torch.randint(0, 2, (100,))  # 100 labels, each 0 or 1

# Instantiate the dataset
dataset = MyDataset(data, labels)

# Test the dataset
print("Dataset size:", len(dataset))
print("Sample 0:", dataset[0])
```
The output looks like the following (the values are random and differ between runs):
```
Dataset size: 100
Sample 0: (tensor([-0.2006, 0.7304, -1.3911, -0.4408, 1.1447]), tensor(0))
```
* * *
## torch.utils.data.DataLoader
DataLoader is a data loader provided by PyTorch for batch loading datasets.
It provides the following features:
* **Batch loading**: By setting `batch_size`.
* **Data shuffling**: By setting `shuffle=True`.
* **Parallel loading**: By setting `num_workers` to the number of worker subprocesses (`0` loads data in the main process).
* **Iterative access**: Iterate over the DataLoader with a `for` loop to get one batch at a time.
## Example
```python
import torch
from torch.utils.data import Dataset, DataLoader

# Custom dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # Store the data and labels
        self.data = data
        self.labels = labels

    def __len__(self):
        # Return the dataset size
        return len(self.data)

    def __getitem__(self, idx):
        # Return the sample and label at position idx
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Generate sample data
data = torch.randn(100, 5)            # 100 samples, each with 5 features
labels = torch.randint(0, 2, (100,))  # 100 labels, each 0 or 1

# Instantiate the dataset
dataset = MyDataset(data, labels)

# Instantiate the DataLoader: batches of 10, reshuffled every epoch, loaded in the main process
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)

# Iterate through the DataLoader
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}")
    print("Data:", batch_data)
    print("Labels:", batch_labels)
```
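A few batching details are easy to check with a small sketch (the 100-sample `TensorDataset` of sample IDs and random labels is an assumption for illustration): `len(loader)` counts batches, `drop_last=True` discards an incomplete final batch, and passing a seeded `torch.Generator` makes `shuffle=True` reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (assumed): sample IDs 0..99 stand in for features so the order is visible
ids = torch.arange(100)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(ids, labels)

# batch_size=32 does not divide 100, so the final batch holds the remaining 4 samples
loader = DataLoader(dataset, batch_size=32, shuffle=True)
sizes = [len(batch_ids) for batch_ids, _ in loader]
print(len(loader), sizes)  # 4 [32, 32, 32, 4]

# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(dataset, batch_size=32, drop_last=True)
print(len(loader_drop))  # 3


def epoch_order(seed):
    # A seeded generator drives the sampler's random permutation
    g = torch.Generator().manual_seed(seed)
    seeded_loader = DataLoader(dataset, batch_size=32, shuffle=True, generator=g)
    return torch.cat([batch_ids for batch_ids, _ in seeded_loader]).tolist()


# Same seed gives the same shuffled order, and every sample still appears exactly once
print(epoch_order(0) == epoch_order(0))            # True
print(sorted(epoch_order(0)) == list(range(100)))  # True
```

When `num_workers` is greater than 0 on platforms that start workers with `spawn` (the default on Windows and macOS), the loading code should sit under an `if __name__ == "__main__":` guard.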