
# PyTorch Datasets

In deep learning tasks, data loading and preprocessing are crucial steps. PyTorch provides powerful tools for both, mainly:

* **`torch.utils.data.Dataset`**: An abstract class that represents a dataset. A custom dataset subclasses it and implements `__len__` (returns the dataset size) and `__getitem__` (returns the sample at a given index).
* **`torch.utils.data.TensorDataset`**: A ready-made dataset that wraps one or more tensors sharing the same first dimension. Sample `i` is the tuple of each tensor's `i`-th entry, which makes it a natural fit for data-label pairs.
* **`torch.utils.data.DataLoader`**: An iterable that wraps a Dataset and adds batching, shuffling, and multi-process loading, making it easy to feed data into model training.
* **`torchvision.datasets.ImageFolder`**: Loads images from a directory in which each subfolder holds the images of one class, which suits image classification tasks.

### PyTorch Built-in Datasets

PyTorch provides many commonly used datasets through the `torchvision.datasets` module, for example:

* **MNIST**: Handwritten digit images, used for image classification.
* **CIFAR-10**: 60,000 32x32 color images in 10 classes, used for image classification (CIFAR-100 is its 100-class counterpart).
* **COCO**: A dataset for common object detection, segmentation, and keypoint detection, containing over 330k images and 2.5M object instances.
* **ImageNet**: Over 14 million images, used for image classification and object detection.
* **STL-10**: 96x96 color images in 10 classes (100,000 unlabeled images plus 13,000 labeled ones), used for image classification and unsupervised feature learning.
* **Cityscapes**: 5,000 finely annotated urban street scene images, used for semantic segmentation.
* **SQuAD**: The Stanford Question Answering Dataset, used for machine reading comprehension. Unlike the others, it is a text dataset and is not part of `torchvision.datasets`; torchtext provides it as `torchtext.datasets.SQuAD1` and `torchtext.datasets.SQuAD2`.

The image datasets above can be loaded through classes in the `torchvision.datasets` module; any other data can be loaded by writing a custom `Dataset`.
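As a quick illustration of `TensorDataset` from the list at the top, here is a minimal sketch. The tensors `features` and `labels` are made-up toy data, not one of the datasets listed above.

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical toy data: 100 samples with 5 features each, plus binary labels
features = torch.randn(100, 5)
labels = torch.randint(0, 2, (100,))

# TensorDataset indexes every wrapped tensor along its first dimension,
# so all tensors must have the same size in dimension 0
dataset = TensorDataset(features, labels)

x0, y0 = dataset[0]        # the tuple (features[0], labels[0])
print(len(dataset))        # 100
print(x0.shape, y0.shape)  # torch.Size([5]) torch.Size([])
print(torch.equal(x0, features[0]), torch.equal(y0, labels[0]))  # True True
```

For plain in-memory tensors this avoids writing a `Dataset` subclass at all; a custom class is only needed when samples must be read from disk or transformed on the fly.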
### torchvision and torchtext

* **torchvision**: A computer vision library that provides APIs and dataset interfaces for image data, including dataset loaders and common image transformations.
* **torchtext**: A natural language processing toolkit that provides tools for text data processing and modeling, including preprocessing and data loading. Note that torchtext is no longer actively developed; 0.18 was its last release.

* * *

## torch.utils.data.Dataset

`Dataset` is the abstract class that represents a dataset in PyTorch. A custom dataset inherits from `torch.utils.data.Dataset` and overrides the following two methods:

* `__len__`: Returns the size of the dataset.
* `__getitem__`: Returns a data sample and its label for a given index.

## Example

```python
import torch
from torch.utils.data import Dataset

# Custom dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # Store the data and labels
        self.data = data
        self.labels = labels

    def __len__(self):
        # Return the dataset size
        return len(self.data)

    def __getitem__(self, idx):
        # Return the sample and label at index idx
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Generate sample data
data = torch.randn(100, 5)            # 100 samples, each with 5 features
labels = torch.randint(0, 2, (100,))  # 100 labels, each 0 or 1

# Instantiate the dataset
dataset = MyDataset(data, labels)

# Test the dataset
print("Dataset size:", len(dataset))
print("Sample 0:", dataset[0])
```

The output looks like this (the values are random, so they differ on every run):

```
Dataset size: 100
Sample 0: (tensor([-0.2006,  0.7304, -1.3911, -0.4408,  1.1447]), tensor(0))
```

* * *

## torch.utils.data.DataLoader

`DataLoader` is the data loader PyTorch provides for loading a dataset in batches. It offers the following features:

* **Batch loading**: Set `batch_size`.
* **Data shuffling**: Set `shuffle=True` to reshuffle the data at every epoch.
* **Parallel loading**: Set `num_workers` to load data in worker subprocesses.
* **Iterative access**: Iterate over the loader to get the data batch by batch.
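To make batch loading concrete, the following sketch shows how `batch_size` determines the number of batches. It uses a toy `TensorDataset` of 100 samples (an assumption for illustration) and also shows `drop_last`, a real `DataLoader` parameter not mentioned above that discards the final incomplete batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 100 samples with 5 features and binary labels
dataset = TensorDataset(torch.randn(100, 5), torch.randint(0, 2, (100,)))

batch_size = 32

# By default the last, smaller batch is kept: ceil(100 / 32) = 4 batches
loader = DataLoader(dataset, batch_size=batch_size)
sizes = [x.shape[0] for x, _ in loader]
print(len(loader), sizes)            # 4 [32, 32, 32, 4]

# drop_last=True discards the incomplete final batch: 100 // 32 = 3 batches
loader_drop = DataLoader(dataset, batch_size=batch_size, drop_last=True)
sizes_drop = [x.shape[0] for x, _ in loader_drop]
print(len(loader_drop), sizes_drop)  # 3 [32, 32, 32]
```

Dropping the last batch can be useful when a model or loss assumes every batch has the same size; otherwise the default keeps every sample.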
## Example

```python
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Custom dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # Store the data and labels
        self.data = data
        self.labels = labels

    def __len__(self):
        # Return the dataset size
        return len(self.data)

    def __getitem__(self, idx):
        # Return the sample and label at index idx
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Generate sample data
data = torch.randn(100, 5)            # 100 samples, each with 5 features
labels = torch.randint(0, 2, (100,))  # 100 labels, each 0 or 1

# Instantiate the dataset
dataset = MyDataset(data, labels)

# Instantiate the DataLoader: batches of 10, reshuffled every epoch,
# loaded in the main process (num_workers=0). With num_workers > 0 on
# Windows or macOS, run the loop under `if __name__ == "__main__":`.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)

# Iterate through the DataLoader
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}")
    print("Data:", batch_data)
    print("Labels:", batch_labels)
```
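The shuffling behavior can be checked directly. This sketch uses a toy dataset whose samples are simply the integers 0 to 9 (an assumption chosen so the visiting order is easy to read). It shows that `shuffle=True` visits every sample exactly once per epoch, and that passing a seeded `torch.Generator` through the `generator` parameter of `DataLoader` makes the order reproducible. The helpers `epoch_order` and `seeded_loader` are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: sample i holds the value i
dataset = TensorDataset(torch.arange(10))

def epoch_order(loader):
    # Concatenate the values from every batch of one pass over the loader
    return torch.cat([batch for (batch,) in loader]).tolist()

# shuffle=True draws a new random permutation at the start of every epoch
loader = DataLoader(dataset, batch_size=4, shuffle=True)
first, second = epoch_order(loader), epoch_order(loader)
print(first)
print(second)

# Each epoch visits every sample exactly once, whatever the order
print(sorted(first) == list(range(10)))  # True

def seeded_loader(seed):
    # A seeded generator fixes the permutation drawn by the sampler
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=4, shuffle=True, generator=g)

print(epoch_order(seeded_loader(0)) == epoch_order(seeded_loader(0)))  # True
```

The two unseeded epochs will usually print different orders; with the same seed, the order is identical across runs, which helps when debugging or comparing experiments.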