
# PyTorch Dataset and DataLoader

In PyTorch, data processing and loading are critical steps in the deep learning training process. To handle data efficiently, PyTorch provides `torch.utils.data.Dataset` and `torch.utils.data.DataLoader`, which help manage datasets, batch loading, and data augmentation.

Overview of PyTorch data processing and loading:

* **Custom Dataset**: Load your own data by subclassing `torch.utils.data.Dataset`.
* **DataLoader**: `DataLoader` loads data in batches and can shuffle it. It can also load batches in parallel using worker processes (`num_workers`).
* **Data Preprocessing and Augmentation**: Use `torchvision.transforms` for common image preprocessing and augmentation operations that help the model generalize.
* **Loading Standard Datasets**: `torchvision.datasets` provides many common datasets, which simplifies data loading.
* **Multiple Data Sources**: Combine multiple `Dataset` instances to handle data from different sources.

## Custom Dataset

`torch.utils.data.Dataset` is an abstract class for building datasets from your own data sources. Subclass it and implement the following two methods:

* `__len__(self)`: Returns the number of samples in the dataset.
* `__getitem__(self, idx)`: Returns the sample at index `idx`.

Suppose we have a simple CSV file or some list data. We can then create our own dataset by subclassing `Dataset`.
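For the CSV case, one approach is a minimal sketch that parses the file once in `__init__` and then indexes into tensors in `__getitem__`. Several details are assumptions made only for illustration:

* the class name `CsvDataset` is hypothetical;
* the file is assumed to have a header row, all columns numeric, and the label in the last column;
* `io.StringIO` stands in for a real file opened with `open('data.csv', newline='')`.

```python
import csv
import io

import torch
from torch.utils.data import Dataset


class CsvDataset(Dataset):
    """Hypothetical CSV dataset: all columns but the last are features, the last is the label."""

    def __init__(self, file_obj):
        reader = csv.reader(file_obj)
        next(reader)  # Skip the (assumed) header row
        rows = [[float(v) for v in row] for row in reader if row]
        # Parse everything once up front so __getitem__ stays cheap
        self.features = torch.tensor([r[:-1] for r in rows], dtype=torch.float32)
        self.labels = torch.tensor([r[-1] for r in rows], dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (features, label) pair
        return self.features[idx], self.labels[idx]


# In-memory stand-in for a real CSV file (toy data)
csv_text = "x1,x2,label\n1,2,1\n3,4,0\n5,6,1\n"
dataset = CsvDataset(io.StringIO(csv_text))
print(len(dataset))  # 3
print(dataset[1])    # (tensor([3., 4.]), tensor(0.))
```

Converting all rows to tensors in `__init__` suits small files that fit in memory. For very large files, `__getitem__` could read rows lazily instead.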
## Example

```python
import torch
from torch.utils.data import Dataset

# Custom dataset class
class MyDataset(Dataset):
    def __init__(self, X_data, Y_data):
        """
        Initialize the dataset. X_data and Y_data are lists or arrays.
        X_data: input features
        Y_data: target labels
        """
        self.X_data = X_data
        self.Y_data = Y_data

    def __len__(self):
        """Return the dataset size"""
        return len(self.X_data)

    def __getitem__(self, idx):
        """Return the sample at the given index"""
        x = torch.tensor(self.X_data[idx], dtype=torch.float32)  # Convert one sample to a tensor
        y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
        return x, y

# Example data
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]  # Input features
Y_data = [1, 0, 1, 0]  # Target labels

# Create dataset instance
dataset = MyDataset(X_data, Y_data)
```

* * *

## Loading Data with DataLoader

`DataLoader` is the PyTorch tool for loading data from a `Dataset` in batches. It can shuffle the samples. With `num_workers > 0`, it also loads batches in parallel worker processes, which improves training efficiency.

## Example

```python
from torch.utils.data import DataLoader

# Create a DataLoader; batch_size sets the number of samples per batch
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Print the loaded data
for epoch in range(1):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f'Batch {batch_idx + 1}:')
        print(f'Inputs: {inputs}')
        print(f'Labels: {labels}')
```

* **`batch_size`**: The number of samples per batch.
* **`shuffle`**: Whether to reshuffle the data at every epoch. This is usually enabled for training.
* **`drop_last`**: If the dataset size is not divisible by `batch_size`, setting this to `True` drops the last incomplete batch.

Output (the order differs between runs because `shuffle=True`):

```
Batch 1:
Inputs: tensor([[3., 4.],
        [1., 2.]])
Labels: tensor([0., 1.])
Batch 2:
Inputs: tensor([[7., 8.],
        [5., 6.]])
Labels: tensor([0., 1.])
```

In each iteration, `DataLoader` returns one batch containing the input features (`inputs`) and the target labels (`labels`).
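The `drop_last` option is easiest to see with a dataset whose size is not a multiple of `batch_size`. Below is a minimal sketch using `TensorDataset` with made-up toy data: five samples and `batch_size=2`. It leaves `num_workers` at its default of 0, so everything runs in the main process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (made up for illustration): 5 samples with 2 features each
X = torch.arange(10, dtype=torch.float32).reshape(5, 2)
y = torch.tensor([1., 0., 1., 0., 1.])
dataset = TensorDataset(X, y)  # dataset[i] returns (X[i], y[i])

# 5 samples are not divisible by batch_size=2
keep_loader = DataLoader(dataset, batch_size=2)                  # drop_last=False (default)
drop_loader = DataLoader(dataset, batch_size=2, drop_last=True)  # discard the incomplete batch

# Collect the number of samples in each batch
sizes_keep = [inputs.shape[0] for inputs, _ in keep_loader]
sizes_drop = [inputs.shape[0] for inputs, _ in drop_loader]

print(sizes_keep, len(keep_loader))  # [2, 2, 1] 3
print(sizes_drop, len(drop_loader))  # [2, 2] 2
```

With `num_workers > 0`, batches are produced in worker subprocesses. On platforms that start workers with `spawn` (Windows, macOS), the loading loop should then sit under an `if __name__ == "__main__":` guard.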
* * *

## Preprocessing and Data Augmentation

Data preprocessing and augmentation are crucial for improving model performance. PyTorch provides the `torchvision.transforms` module for common image preprocessing and augmentation operations such as rotation, cropping, and normalization.

Common image preprocessing operations:

## Example

```python
import torchvision.transforms as transforms
from PIL import Image

# Define the preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize image to 128x128
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

# Load image (convert to RGB so it has the 3 channels Normalize expects)
image = Image.open('image.jpg').convert('RGB')

# Apply preprocessing
image_tensor = transform(image)
print(image_tensor.shape)  # torch.Size([3, 128, 128])
```

* **`transforms.Compose()`**: Chains multiple transforms together.
* **`transforms.Resize()`**: Resizes the image.
* **`transforms.ToTensor()`**: Converts a PIL image (H x W x C) into a PyTorch tensor (C x H x W) and scales pixel values to the `[0, 1]` range.
* **`transforms.Normalize()`**: Normalizes each channel as `(x - mean) / std`. This is usually required when using pre-trained models; the values above are the ImageNet channel statistics.

### Image Data Augmentation

Data augmentation increases data diversity by applying random transformations to the training data, such as random flipping, rotation, and cropping. This helps models generalize better.

## Example

```python
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Flip horizontally with probability 0.5 (default)
    transforms.RandomRotation(30),  # Rotate by a random angle in [-30, 30] degrees
    transforms.RandomResizedCrop(128),  # Crop a random area and aspect ratio, then resize to 128x128
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```

These augmentations are combined with `transforms.Compose()`. Their random parameters are sampled each time an image is loaded, so the same image is transformed differently in each epoch of training.

* * *

## Loading Image Datasets

For image datasets, `torchvision.datasets` provides many common datasets (such as CIFAR-10, ImageNet, and MNIST)
and tools for loading image data. Loading the MNIST dataset:

## Example

```python
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define preprocessing operations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Map single-channel (grayscale) values from [0, 1] to [-1, 1]
])

# Download and load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Inspect the first training batch
for inputs, labels in train_loader:
    print(inputs.shape)  # torch.Size([64, 1, 28, 28])
    print(labels.shape)  # torch.Size([64])
    break
```

* `datasets.MNIST()` downloads the MNIST dataset (when `download=True`) and loads it.
* The `transform` parameter applies preprocessing to each sample.
* `train=True` and `train=False` select the training set and the test set, respectively.

* * *

## Using Multiple Data Sources (Multi-source Dataset)

Your data may span multiple files or sources, such as several image folders. In that case, you can write a custom `Dataset` subclass that loads from all of them. PyTorch also provides `ConcatDataset` (for map-style datasets) and `ChainDataset` (for `IterableDataset` instances) to join multiple datasets.

For example, if we have data from several image folders, we can merge them into one dataset:

## Example

```python
from torch.utils.data import ConcatDataset, DataLoader

# Assume dataset1 and dataset2 are two existing Dataset objects
combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)
```
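To see how `ConcatDataset` maps indices, here is a minimal sketch with two small `TensorDataset` objects standing in for two data sources. The tensors are made-up toy data, not real image folders. Indices below the length of the first dataset go to it, and the remaining indices are offset into the second. The `cumulative_sizes` attribute records these boundaries.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Two toy sources (hypothetical stand-ins for, e.g., two image folders)
source_a = TensorDataset(torch.zeros(3, 2), torch.zeros(3))  # 3 samples, label 0
source_b = TensorDataset(torch.ones(2, 2), torch.ones(2))    # 2 samples, label 1

combined = ConcatDataset([source_a, source_b])
print(len(combined))              # 5
print(combined.cumulative_sizes)  # [3, 5]
print(combined[3])                # (tensor([1., 1.]), tensor(1.)), i.e. source_b[0]

# Shuffling draws indices across both sources, so a batch can mix samples from each
loader = DataLoader(combined, batch_size=2, shuffle=True)
total = sum(labels.shape[0] for _, labels in loader)
print(total)                      # 5
```

All sources should return samples with the same structure and compatible shapes. Otherwise, the default collate function cannot stack them into a batch.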