我已经通过这个pytorch API通过设置download=True下载了ImageNet数据集。但我不能遍历数据加载器。
错误显示“DataLoader对象不支持索引”
trainset = torch.utils.data.DataLoader(
datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)
我尝试了一个简单的方法,我只是试着运行以下代码,
trainloader[0]
在根目录中,模式为
root/
train/
n01440764/
n01443537/
n01443537_2.jpg
官方网站上的文档没有说别的。https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
我做错了什么?
4条答案
按热度按时间eqqqjvef1#
答案很简单(除了另一个答案中提到的错误)。
DataLoader
没有__getitem__
方法(请参阅源代码)。它用于数据(或数据批)的迭代,而不是随机访问。如果你想访问特定的元素,你应该使用
torch.utils.data.Dataset
,在你的例子中:批量获取
如果你想得到一个批处理,你可以迭代它,然后中断:
DataLoader
以默认或指定的方式创建随机索引(参见samplers),因此没有__getitem__
,因为它对这个对象没有意义。你也可以继承
DataLoader
并创建你自己的__getitem__
函数,做你想做的事情(虽然更复杂)。完整示例
上面应该打印出第一批里面的东西。
gg58donl2#
解决方案
ddhy6vgd3#
torch.utils.data.DataLoader()
的输入数据集的类型应该是torch.utils.data.Dataset
,而不是torch.utils.data.DataLoader
,这就是您在上面代码中所做的。所以,你上面的代码应该是:
有关更多详细信息,请在此处查看官方torch文档。
knsnq2tg4#
最后我想到了这个肮脏的解决方案: