pytorch “DataLoader”对象不支持索引

x0fgdtte  于 2023-05-17  发布在  其他
关注(0)|答案(4)|浏览(290)

我已经通过这个pytorch API通过设置download=True下载了ImageNet数据集。但我不能遍历数据加载器。
错误显示“DataLoader对象不支持索引”

trainset = torch.utils.data.DataLoader(
    datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                      download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)

我尝试了一个简单的方法,我只是试着运行以下代码,

trainloader[0]

在根目录中,模式为

root/  
    train/  
          n01440764/
          n01443537/ 
                   n01443537_2.jpg

官方网站上的文档没有说别的。https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
我做错了什么?

eqqqjvef

eqqqjvef1#

答案很简单(除了另一个答案中提到的错误)。

DataLoader没有__getitem__方法(请参阅源代码)。

它用于数据(或数据批)的迭代,而不是随机访问。如果你想访问特定的元素,你应该使用torch.utils.data.Dataset,在你的例子中:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', )
trainset[0]

批量获取

如果你想得到一个批处理,你可以迭代它,然后中断:

for batch in dataloader:
    print(batch) # or anything else you want to do
    break

DataLoader以默认或指定的方式创建随机索引(参见samplers),因此没有__getitem__,因为它对这个对象没有意义。
你也可以继承DataLoader并创建你自己的__getitem__函数,做你想做的事情(虽然更复杂)。

完整示例

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)

for batch in trainloader:
    print(batch)
    break

上面应该打印出第一批里面的东西。

gg58donl

gg58donl2#

解决方案

input_transform = standard_transforms.Compose([
    transforms.Resize((255,255)), # to Make sure all the 
    transforms.CenterCrop(224),   # imgs are at the same size 
    transforms.ToTensor()
])  

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                             split='train', download=False, transform = input_transform)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)

for batch_idx, data in enumerate(trainloader, 0):
    x, y = data 
    break
ddhy6vgd

ddhy6vgd3#

torch.utils.data.DataLoader()的输入数据集的类型应该是torch.utils.data.Dataset,而不是torch.utils.data.DataLoader,这就是您在上面代码中所做的。
所以,你上面的代码应该是:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', 
                                          split='train', 
                                          download=False)

trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=1, 
                                          shuffle=False, 
                                          num_workers=1)

有关更多详细信息,请在此处查看官方torch文档。

knsnq2tg

knsnq2tg4#

最后我想到了这个肮脏的解决方案:

def Dataloader_by_Index(data_loader, target=0):
    for index, data in enumerate(data_loader):
        if index == target:
            return data
    return None
fifth_element = Dataloader_by_Index(my_data_loader, target=4)

相关问题