Dataset: stores the samples and their corresponding labels
DataLoader: wraps an iterable around the Dataset
Domain-specific libraries, such as TorchText, TorchVision and TorchAudio, provide ready-made datasets.
torch.utils.data.Dataset
torch.utils.data.Dataset: base class representing a Dataset.
Map-style datasets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
| import os
import pandas as pd
class CustomDataset(Dataset):
    """Map-style dataset pairing image files with labels from an annotations CSV.

    Column 0 of the annotations table holds the image file name (joined onto
    ``img_dir``) and column 1 holds the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Load the annotations table and remember where the images live.

        Args:
            annotations_file: Path or buffer handed to ``pandas.read_csv``.
                ``read_csv`` is called with its defaults, so the first row is
                consumed as the header.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to every loaded image.
            target_transform: Optional callable applied to every label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples (rows in the annotations table).

        Many Sampler implementations and the default options of DataLoader
        expect ``__len__`` to report the size of the dataset.
        """
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Fetch the ``(image, label)`` sample stored at row ``idx``.

        Args:
            idx: Positional row index into the annotations table.

        Returns:
            A tuple ``(image, label)`` after the optional transforms have been
            applied.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # NOTE(review): read_image is not imported in this snippet; presumably
        # it is torchvision.io.read_image -- confirm and import it at file level.
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container-like module) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
|
Iterable-style datasets
Dataset libraries
All datasets are subclasses of torch.utils.data.Dataset
- torchvision.datasets: torchvision.datasets provides many built-in datasets, as well as utility classes for building your own datasets.
- torchtext.datasets
- torchaudio.datasets
torch.utils.data.Sampler
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class AccedingSequenceLengthSampler(Sampler[int]):
    """Sampler that yields dataset indices ordered by ascending item length."""

    def __init__(self, data: List[str]) -> None:
        """Keep a reference to the items whose lengths drive the ordering.

        Args:
            data: The items to sample from; ``len()`` is taken of each one.
        """
        self.data = data

    def __len__(self) -> int:
        """Return the number of indices produced by one pass of the sampler."""
        return len(self.data)

    def __iter__(self) -> Iterator[int]:
        """Yield every index of ``data``, shortest item first."""
        sizes = torch.tensor([len(x) for x in self.data])
        # argsort turns the lengths into the index order that sorts them.
        yield from torch.argsort(sizes).tolist()
class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
    """Batch sampler that groups indices, ordered by ascending item length,
    into consecutive batches of ``batch_size``."""

    def __init__(self, data: List[str], batch_size: int) -> None:
        """Store the items and the requested batch size.

        Args:
            data: The items to sample from; ``len()`` is taken of each one.
            batch_size: Number of indices per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not a positive number.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.data = data
        self.batch_size = batch_size

    def __len__(self) -> int:
        """Return the number of batches (the last one may be partial)."""
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[List[int]]:
        """Yield lists of indices, shortest items first, ``batch_size`` at a time."""
        sizes = torch.tensor([len(x) for x in self.data])
        order = torch.argsort(sizes).tolist()
        # Slice in fixed steps instead of torch.chunk(order, len(self)):
        # chunk balances the chunk sizes (6 items, batch_size=4 -> 3 + 3
        # rather than 4 + 2) and raises when there are zero chunks (empty data).
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]
|
torch.utils.data.DataLoader
DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html
Creating a Custom Dataset for your files
A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`.
1
2
3
4
5
6
7
|
from torch.utils.data import DataLoader
# Wrap training_data in a DataLoader configured for shuffled batches of 64 samples.
# NOTE(review): training_data is not defined in this snippet; it must be created beforehand.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# Pull the first batch from a fresh iterator. The unpacking assumes each batch
# is a (features, labels) pair -- TODO confirm against the dataset in use.
train_features, train_labels = next(iter(train_dataloader))
|
https://docs.pytorch.org/docs/stable/data.html