在PyTorch中使用Pyothon迭代器作为数据集的最佳方法

xxhby3vn  于 2023-04-12  发布在  其他
关注(0)|答案(2)|浏览(173)

PyTorch DataLoader将数据集转换为可迭代对象。我已经有了一个迭代器,它可以生成数据样本,我想用它来训练和测试。我使用迭代器的原因是样本的总数太大,无法存储在内存中。我想批量加载样本进行训练。
最好的方法是什么?我可以不使用自定义DataLoader吗?PyTorch数据加载器不喜欢将迭代器作为输入。下面是我想要做的一个最小示例,它会产生错误“object of type 'generator' has no len()"。

import torch
from torch import nn
from torch.utils.data import DataLoader

def example_iterator():
    for i in range(10):
        yield i
    

BATCH_SIZE = 3
train_dataloader = DataLoader(example_iterator(),
                        batch_size = BATCH_SIZE,
                        shuffle=False)

print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")

我试图从迭代器中获取数据,并利用PyTorch DataLoader的功能。我给出的示例是我想要实现的最小示例,但它会产生错误。

ahy6op9u

ahy6op9u1#

Meh,这是一个生成器而不是迭代器。但无论如何,这里有一个解决方案:

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

def example_iterator():
    for i in range(10):
        yield i

class MyDataset(Dataset):
    def __init__(self):
        self.generator = example_iterator()
        
    def __getitem__(self, idx):
        return next(self.generator)
        
    def __len__(self):
        return 10 #here you have to put the len

BATCH_SIZE = 3
train_dataloader = DataLoader(MyDataset(),
                        batch_size = BATCH_SIZE,
                        shuffle=False)
6qfn3psc

6qfn3psc2#

PyTorch的DataLoader实际上官方支持可迭代数据集,但它必须是torch.utils.data.IterableDataset子类的示例:
可迭代样式的数据集是实现__iter__()协议的IterableDataset的子类的示例,表示可迭代的数据样本
所以你的代码应该写为:

from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

train_dataloader = DataLoader(MyIterableDataset(example_iterator()),
                              batch_size = BATCH_SIZE,
                              shuffle=False)

相关问题