Dataset 和 DataLoader 的区别是什么

torch.utils.data 中,有两个类,一个是Dataset,另一个是DataLoader

这两个类的主要区别是什么?

Dataset 一般用于读取数据集的基础数据。例如,在 torch 给出的官网说明中,用于展示数据集的数量,以及用于神经网络训练的单个样本。


class FaceLandmarkDataset(Dataset):
  def __init__(self):
    ...
  def __len__(self):
    """表示数据集的数量"""
  def __getitem__(self, idx):
    """返回某个下标的数据组合。例如,如果是图像和标签,应该是 {'img': img, 'label': label} """

与之相对的,DataLoader 则是一个可以并行读取数据的类。一般情况下,不需要进行继承然后改写。

所以我们主要说说怎么用。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)

我们可以看到,因为 DataLoader 的存在,因此实际上我们在实现 Dataset 的时候:

  1. 不需要关注 shuffle: 为了神经网络训练的 batch不聚集在一个地方
  2. 不需要考虑并行读取,因为有 num_worker
  3. 不需要考虑 batch_size

因此,DataLoader 可以比较容易的完成一些数据集处理前的必要工作。

如果使用 lightning,那么还需要进一步了解 DataModule。我会在另外一篇博客中说明。


也可以看看