PyTorch DataLoader将数据集转换为可迭代对象。我已经有了一个迭代器,它可以生成数据样本,我想用它来训练和测试。我使用迭代器的原因是样本的总数太大,无法存储在内存中。我想批量加载样本进行训练。
最好的方法是什么?我可以不使用自定义DataLoader吗?PyTorch数据加载器不喜欢将迭代器作为输入。下面是我想要做的一个最小示例,它会产生错误“object of type 'generator' has no len()"。
import torch
from torch import nn
from torch.utils.data import DataLoader
def example_iterator():
for i in range(10):
yield i
BATCH_SIZE = 3
train_dataloader = DataLoader(example_iterator(),
batch_size = BATCH_SIZE,
shuffle=False)
print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
我试图从迭代器中获取数据,并利用PyTorch DataLoader的功能。我给出的示例是我想要实现的最小示例,但它会产生错误。
2条答案
按热度按时间ahy6op9u1#
Meh,这是一个生成器而不是迭代器。但无论如何,这里有一个解决方案:
6qfn3psc2#
PyTorch的
DataLoader
实际上官方支持可迭代数据集,但它必须是torch.utils.data.IterableDataset
子类的示例:可迭代样式的数据集是实现
__iter__()
协议的IterableDataset的子类的示例,表示可迭代的数据样本所以你的代码应该写为: