Can a PyTorch DataLoader start from an empty dataset?

xmq68pz9 asked on 2023-10-20

I have a dataset that lives in a deque buffer, and I would like to load random batches from it with a DataLoader. The buffer starts out empty; data will be added to the buffer before it is sampled.

from collections import deque
from torch_geometric.loader import DataLoader  # per the traceback below

self.buffer = deque([], maxlen=capacity)
self.batch_size = batch_size
self.loader = DataLoader(self.buffer, batch_size=batch_size, shuffle=True, drop_last=True)

However, this raises the following error:

File "env/lib/python3.8/site-packages/torch_geometric/loader/dataloader.py", line 78, in __init__
    super().__init__(dataset, batch_size, shuffle,
  File "env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 268, in __init__
    sampler = RandomSampler(dataset, generator=generator)
  File "env/lib/python3.8/site-packages/torch/utils/data/sampler.py", line 102, in __init__
    raise ValueError("num_samples should be a positive integer "
ValueError: num_samples should be a positive integer value, but got num_samples=0

It turns out the RandomSampler class checks at initialization that num_samples is positive, which is what triggers the error:

if not isinstance(self.num_samples, int) or self.num_samples <= 0:
    raise ValueError("num_samples should be a positive integer "
                     "value, but got num_samples={}".format(self.num_samples))

Why does it check this here, even though RandomSampler does *not* support datasets that change size at runtime?
One workaround is to use an IterableDataset, but I would like to keep DataLoader's shuffling.
Can you think of a good way to use a DataLoader with a deque? Thanks a lot!
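(For context, a common alternative in replay-buffer code is to bypass DataLoader entirely and draw random batches straight from the deque. A minimal sketch, with illustrative names, not part of the original question:)

```python
import random
from collections import deque

buffer = deque(maxlen=1000)          # starts empty, filled during training

def sample_batch(batch_size):
    """Return a random batch, or None while the buffer is still too small."""
    if len(buffer) < batch_size:
        return None
    # the list() copy keeps random.sample happy regardless of Python version
    return random.sample(list(buffer), batch_size)

assert sample_batch(4) is None       # empty buffer: nothing to sample yet
buffer.extend(range(10))             # data arrives during training
batch = sample_batch(4)
```

This sidesteps the sampler entirely, at the cost of losing DataLoader's collation and worker machinery.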

wj8zmpe1 (answer 1)

The problem here is neither the use of a deque nor the fact that the dataset can grow dynamically. The problem is that you start with a dataset of size zero, which is invalid.
The simplest solution is to start with an arbitrary placeholder object in the deque and then remove it on the fly.
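A sketch of this placeholder trick (variable names are illustrative; it relies on RandomSampler reading len(dataset) lazily, at iteration time rather than only at construction):

```python
from collections import deque
from torch.utils.data import DataLoader

# Seed the deque with one dummy item so len() > 0 when the DataLoader
# (and its RandomSampler) is constructed.
buffer = deque([0], maxlen=100)
loader = DataLoader(buffer, batch_size=4, shuffle=True, drop_last=True)

buffer.popleft()           # remove the placeholder; the buffer is empty again
buffer.extend(range(8))    # later, real data arrives
batches = list(loader)     # sampling now sees the 8 real items
```

Each fresh iteration over the loader re-reads the deque's current length, so the buffer can keep growing between epochs.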

k5hmc34c (answer 2)

An empty dataset can be defined as follows:

from torch.utils.data import DataLoader, Dataset

class EmptyDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return 0

    def __getitem__(self, index):
        raise IndexError("Empty dataset cannot be indexed")

A loader can be created and used with shuffle=False:

>>> loader = DataLoader(EmptyDataset(), batch_size=16, shuffle=False)

Iterating works fine as well:

for batch in loader:
    print(batch)

What about shuffle=True?

However, when you specify shuffle=True, it breaks in RandomSampler:

>>> loader = DataLoader(EmptyDataset(), batch_size=16, shuffle=True)
ValueError: num_samples should be a positive integer value, but got num_samples=0

As a workaround, you can create your own sampler and relax the relevant check from self.num_samples <= 0 to self.num_samples < 0:

from typing import Optional, Sized

from torch.utils.data import RandomSampler

class MyRandomSampler(RandomSampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if not isinstance(self.num_samples, int) or self.num_samples < 0:
            raise ValueError("num_samples should be a non-negative integer "
                             "value, but got num_samples={}".format(self.num_samples))

And then:

>>> dataset = EmptyDataset()
>>> loader = DataLoader(dataset, batch_size=16, sampler=MyRandomSampler(dataset))  # instead of shuffle=True

Note that iterating this loader still does not work, but I leave that to the reader. (The fix is not obvious to me, since the Sampler code is somewhat intertwined.) A simple approach would be to override DataLoader.__iter__ / _get_iterator.
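One possible fix, sketched below and not verified against every PyTorch version: have the custom sampler's __iter__ short-circuit when the dataset is empty, so the loader simply yields no batches instead of failing. This variant bypasses RandomSampler.__init__ entirely rather than copying its body:

```python
from torch.utils.data import DataLoader, Dataset, RandomSampler

class EmptyDataset(Dataset):
    def __len__(self):
        return 0

    def __getitem__(self, index):
        raise IndexError("Empty dataset cannot be indexed")

class MyRandomSampler(RandomSampler):
    def __init__(self, data_source, replacement=False, num_samples=None,
                 generator=None):
        # Set attributes directly, bypassing RandomSampler.__init__'s
        # "num_samples must be positive" check.
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

    def __iter__(self):
        if len(self.data_source) == 0:
            return iter([])        # empty dataset: yield no indices at all
        return super().__iter__()  # otherwise defer to RandomSampler

dataset = EmptyDataset()
loader = DataLoader(dataset, batch_size=16, sampler=MyRandomSampler(dataset))
batches = list(loader)  # no error: simply zero batches while the dataset is empty
```

With no indices produced, the batch sampler never builds a batch, so iterating the empty loader terminates immediately; once the dataset grows, super().__iter__() takes over and shuffles as usual.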
