Pytorch数据加载器更改dict返回值

5w9g7ksd  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(119)

给定一个读取JSON文件的Pytorch数据集:

import csv

from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader2, DataLoader

class MyDataset(IterableDataset):
    def __init__(self, jsonfilename):
        self.filename = jsonfilename

    def __iter__(self):
        with open(self.filename) as fin:
            reader = csv.reader(fin)
            headers = next(reader)
            for line in reader:
                yield dict(zip(headers, line))

content = """imagefile,label
train/0/16585.png,0
train/0/56789.png,0"""

with open('myfile.json', 'w') as fout:
    fout.write(content)

ds = MyDataset("myfile.json")

当我循环遍历数据集时,返回值是json的每一行的dict,例如:

ds = MyDataset("myfile.json")

for i in ds:
    print(i)

[out]:

{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}

但是当我将数据集读入DataLoader时,它会以列表的形式返回dict的值,而不是返回值本身,例如:

ds = MyDataset("myfile.json")
x = DataLoader(dataset=ds)

for i in x:
    print(i)

[out]:

{'imagefile': ['train/0/16585.png'], 'label': ['0']}
{'imagefile': ['train/0/56789.png'], 'label': ['0']}

Q(第一部分):为什么DataLoader要将dict的值更改为列表?

且还

Q(第二部分):如何在使用DataLoader运行__iter__时,使DataLoader仅返回dict的值而不是值列表?DataLoader中是否有一些参数/选项可用于执行此操作?

fae0ux8s

fae0ux8s1#

原因是torch.utils.data.DataLoader中的默认整理行为,它决定了如何合并批处理中的数据样本。默认情况下,使用torch.utils.data.default_collate整理函数,该函数将Map转换为:
Map[K,V_i] -〉Map[K,默认逐份打印([V_1,V_2,...])]
和字符串为:
字符串-〉字符串(未更改)
请注意,如果在示例中将batch_size设置为2,则会得到:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}

作为这些变换的结果。
假设您不需要批处理,您可以通过设置batch_size=None禁用它来获得所需的输出。更多信息请访问:加载批处理和非批处理数据。

6vl6ewon

6vl6ewon2#

详情请参见@GoodDeeds的答案!https://stackoverflow.com/a/73824234/610569

以下答案适用于TL; DR读取器:

问:为什么DataLoader要将dict的值更改为列表?

答:因为有一个隐含的假设,即DataLoader对象的__iter__应该返回一批数据,而不是单个数据。

Q(第二部分):如何在使用DataLoader运行iter时,使DataLoader仅返回dict的值而不是值列表?DataLoader中是否有一些参数/选项可用于执行此操作?

答:由于隐式批量返回行为,最好修改{key: [value1, value2, ...]中的数据返回批量,而不是试图强制DataLoader返回{key: value1}
要更好地理解批处理假设,请尝试batch_size参数:

x = DataLoader(dataset=ds, batch_size=2)

for i in x:
    print(i)

[out]:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}

相关问题