Iterable PyTorch dataset with multiple workers

Asked by xu3bshqb on 2023-03-02

I have a text file that is larger than memory, and I want to create a dataset in PyTorch that reads it line by line, so I don't have to load the whole file into memory at once. I found that PyTorch's IterableDataset is a potential solution to my problem. However, it only works as expected with a single worker; with more than one worker it produces duplicate records. Let me show you an example:
With a testfile.txt containing:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

Define the iterable dataset:

from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):

        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label

    def __iter__(self):

        #Create an iterator over the file, one line at a time
        file_itr = open(self.filename)

        #Map each line using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)

        return mapped_itr

We can now test it:

from torch.utils.data import DataLoader

base_dataset = CustomIterableDatasetv1("testfile.txt")
#Wrap it in a dataloader
dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 1)
for X, y in dataloader:
    print(X,y)

It outputs:

('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

Great. But if I change the number of workers to 2, the output becomes:

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

This is not correct: every sample is duplicated in the dataloader, once per worker.
Is there a way to solve this in PyTorch? So that a dataloader can be created that does not load the whole file into memory and supports multiple workers.
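From what I can tell, each worker process receives its own copy of the dataset and runs the full __iter__ on it. A small diagnostic sketch makes this visible (the TaggedDataset subclass is hypothetical, added here only to tag each sample with the worker that produced it):

import torch
from torch.utils.data import DataLoader

class TaggedDataset(CustomIterableDatasetv1):
    #Same dataset, but each sample also carries the id of the worker
    #that produced it (get_worker_info() is non-None only inside a worker)
    def __iter__(self):
        worker_id = torch.utils.data.get_worker_info().id
        for text, label in super().__iter__():
            yield worker_id, text, label

dataloader = DataLoader(TaggedDataset("testfile.txt"), batch_size=1, num_workers=2)
for worker_id, X, y in dataloader:
    print(worker_id.item(), X, y)
#Both worker 0 and worker 1 print every single line of the file,
#confirming that each worker iterates over the whole file independently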

zdwk9cvp1#

So I found the answer in the torch discussion forum https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3, where they pointed out that I should use the worker info to slice the iterator, so that each worker starts at its own worker id and steps by the total number of workers.
The new dataset looks like this:

import itertools
import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):

        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        worker_total_num = worker_info.num_workers
        worker_id = worker_info.id

        #Create an iterator over the file
        file_itr = open(self.filename)

        #Map each line using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)

        #Add multiworker functionality: worker worker_id takes lines
        #worker_id, worker_id + worker_total_num, worker_id + 2*worker_total_num, ...
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        return mapped_itr

Special thanks to @Ivan, who also pointed out the slicing solution.
With two workers it now returns only the same data as with one worker.
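To double check, re-running the test from the question with two workers (same testfile.txt, same DataLoader setup as above):

base_dataset = CustomIterableDatasetv1("testfile.txt")
dataloader = DataLoader(base_dataset, batch_size=1, num_workers=2)
for X, y in dataloader:
    print(X, y)
#Each line now appears exactly once: worker 0 reads lines 0, 2, 4, ...
#and worker 1 reads lines 1, 3, 5, ...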

eqqqjvef2#

You can access the worker identifier inside the Dataset's __iter__ function using the torch.utils.data.get_worker_info utility. This means you can step through the iterator while adding an offset depending on the worker id. You can wrap the iterator with itertools.islice, which allows you to step through it with a start index and a step.
Here is a minimal example:

import torch
from itertools import islice
from torch.utils.data import IterableDataset, DataLoader

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        uid = torch.utils.data.get_worker_info().id
        #Worker uid starts at offset uid; here the step is batch_size,
        #which in this example happens to equal the number of workers
        itr = islice(range(10), uid, None, self.batch_size)
        return itr

Even though we are using num_workers > 1, iterating over the dataloader yields unique samples:

>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
...     print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])

In your case, you could do:

def __iter__(self):
    # fetch this worker's metadata; get_worker_info() returns None
    # when iterating in the main process (num_workers=0)
    worker_info = torch.utils.data.get_worker_info()
    uid = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    # create an iterator over the file
    file_itr = open(self.filename)

    # map each element using the line_mapper
    mapped_itr = map(self.line_mapper, file_itr)

    # wrap the iterator: each worker starts at its own id and steps
    # by the total number of workers, so the workers read disjoint lines
    step_itr = islice(mapped_itr, uid, None, num_workers)

    return step_itr
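Note that the step here is the total number of workers, not the batch size: the two only coincide in the minimal example above because batch_size happens to equal num_workers. One more caveat: islice still consumes the skipped elements from the underlying file iterator, so each worker still reads the whole file from disk and merely discards the lines that belong to other workers. That keeps memory usage constant, which is what was asked for, but for very large files you could additionally seek each worker to a precomputed byte offset to avoid the redundant reads.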
