在 torch.utils.data
中,有两个类,一个是Dataset
,另一个是DataLoader
。
这两个类的主要区别是什么?
Dataset 一般用于读取数据集的基础数据。例如,在 torch 给出的官网说明中,用于展示数据集的数量,以及用于神经网络训练的单个样本。
class FaceLandmarkDataset(Dataset):
def __init__(self):
...
def __len__(self):
"""表示数据集的数量"""
def __getitem__(self, idx):
"""返回某个下标的数据组合。例如,如果是图像和标签,应该是 {'img': img, 'label': label} """
与之相对的,DataLoader 则是一个可以并行读取数据的类。一般情况下,不需要进行继承然后改写。
所以我们主要说说怎么用。
dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=0)
我们可以看到,因为 DataLoader 的存在,因此实际上我们在实现 Dataset 的时候:
- 不需要关注 shuffle: 为了神经网络训练的 batch不聚集在一个地方
- 不需要考虑并行读取,因为有
num_worker
- 不需要考虑
batch_size
因此,DataLoader 可以比较容易的完成一些数据集处理前的必要工作。
如果使用 lightning,那么还需要进一步了解 DataModule。我会在另外一篇博客中说明。