pytorch Python中的顺序yield调用的延迟

zfciruhq  于 2024-01-09  发布在  Python
关注(0)|答案(1)|浏览(180)

我写了一段代码来读取存储在h5文件中的一组pandas.DataFrames,并在它们的行上进行遍历。我的代码的目的是使用pytorchIterableDataset来处理数据集,但我相信我下面的问题并不局限于pytorch
由于从磁盘阅读每个h5需要一段时间,我实现了以下逻辑
1.代码从磁盘读取第一个文件,并异步预取第二个文件
1.一旦从磁盘中读取了第一个DataFrame,代码就开始迭代它的行(使用`yield from)
1.一旦迭代完成,代码就异步读取下面的文件,并开始从之前预取的文件中产生。
代码可以在这里找到

def _load_next( file_list, index, device, labels, variables):
    if index >= len(file_list):
        return None
        
     thedata=pd.read_hdf(file_list[index], 'df')
     labels=torch.Tensor( thedata[labels].values).to(device)
     variables=torch.Tensor( thedata[variables].values).to(device)     

     return index, (labels,variables)   

class datasets( IterableDataset ):
    def __init__( self, path, device, variables, labels):
        self.files=glob.glob(path)
    self.device=device
        self.variables=variables
        self.labels=labels
        self.restart()
        
    def restart(self):
        print("Re-starting iterator")
        # read first file and submit prefetching of the following
        self.file_index, self.current_data=_load_next(self.files,0, self.device)
        self.prefetch=self.executor.submit(_load_next, self.files, self.file_index+1, self.device)   
        
    def __iter__(self):
       while True:
            yield from zip(self.current_data[0], self.current_data[1])
            result=self.prefetch.result()
            if result is None: 
                self.executor.shutdown(wait=False)
                raise StopIteration
            else:
                self.file_index, self.current_data = result
                self.prefetch=self.executor.submit(_load_next, self.files, self.file_index+1, self.device)

字符串
逻辑运行良好,但是每个yield from调用需要几秒钟,这可能会引入不必要的延迟(延迟实际上比预取以下文件所需的时间更长)。是否有方法消除此延迟,也许异步运行yield from?当然欢迎其他想法!提前感谢!

a8jjtwal

a8jjtwal1#

感谢在讨论中收到的反馈,我设法进一步优化了代码-问题不在于yield调用,而在于torch.Tensor迭代器很慢,但可以异步构造。代码如下所示。
这比以前的版本快得多,但我并不完全满意:神经网络训练和预取似乎在争夺相同的资源(训练速度会减慢,直到预取完成)

def _load_next( file_list, index, device, labels, variables):
    if index >= len(file_list):
        return None
        
     thedata=pd.read_hdf(file_list[index], 'df')
     labels=torch.Tensor( thedata[labels].values).to(device).__iter__()
     variables=torch.Tensor( thedata[variables].values).to(device).__iter__()    

     return index, (labels,variables)   

class datasets( IterableDataset ):
    def __init__( self, path, device, variables, labels):
        self.files=glob.glob(path)
        self.device=device
        self.variables=variables
        self.labels=labels
        self.restart()
        
    def restart(self):
        print("Re-starting iterator")
        # read first file and submit prefetching of the following
        self.file_index, self.current_data=_load_next(self.files,0, self.device)
        self.prefetch=self.executor.submit(_load_next, self.files, self.file_index+1, self.device)   
        
    def __iter__(self):
       while True:
            yield from zip(self.current_data[0], self.current_data[1])
            result=self.prefetch.result()
            if result is None: 
                self.executor.shutdown(wait=False)
                self.restart()
                break
            else:
                self.file_index, self.current_data = result
                self.prefetch=self.executor.submit(_load_next, self.files, self.file_index+1, self.device)

字符串

相关问题