PyTorch DataLoader is very slow

syqv5f0l · asked on 2023-10-20 · in: Other
Follow (0) | Answers (4) | Views (66)

I have a problem with PyTorch's DataLoader: it is very slow.
I wrote a test to demonstrate this. Here is the code:

import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

data = np.load('slices.npy')
data = np.reshape(data, (-1, 1225))
data = torch.FloatTensor(data).to('cuda')
print(data.shape)
# ==> torch.Size([273468, 1225])

class UnlabeledTensorDataset(TensorDataset):
    def __init__(self, data_tensor):
        self.data_tensor = data_tensor
        self.samples = data_tensor.shape[0]

    def __getitem__(self, index):
        return self.data_tensor[index]
    
    def __len__(self):
        return self.samples

test_set = UnlabeledTensorDataset(data)
test_loader = DataLoader(test_set, batch_size=data.shape[0])

start = datetime.datetime.now()
with torch.no_grad():
    for batch in test_loader:
        print(batch.shape)     # ==> torch.Size([273468, 1225])
        y_pred = model(batch)
        loss = torch.sqrt(criterion(y_pred, batch))
        avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))
# ==> 1527.57  (milliseconds)   !!!!!!!!!!!!!!!!!!!!!!!!

start = datetime.datetime.now()
with torch.no_grad():
    print(data.shape)     # ==> torch.Size([273468, 1225])
    y_pred = model(data)
    loss = torch.sqrt(criterion(y_pred, data))
    avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))
# ==> 2.0     (milliseconds)    !!!!!!!!!!!!!!!!!!!!!!!!

I would like to keep using a DataLoader, but I need a way to fix the slowness. Does anyone know why this happens?


mqxuamgl1#

The time difference seems logical to me:

  • On one side, you loop over test_loader, which has to assemble the batch by fetching the 273468 samples one at a time before running inference.
  • On the other, you run a single inference directly on the tensor.
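To illustrate where the time goes, here is a minimal NumPy-only sketch (dummy data, no model or CUDA involved) comparing per-row fetching, which is what the default batch assembly does, against a single vectorized access:

```python
import time
import numpy as np

data = np.random.rand(100_000, 32).astype(np.float32)

# Fetch every row individually, one __getitem__-style call per sample:
t0 = time.perf_counter()
rows = [data[i] for i in range(data.shape[0])]
per_item = time.perf_counter() - t0

# Take the whole array in a single vectorized operation:
t0 = time.perf_counter()
whole = data[:]
single = time.perf_counter() - t0

print(f"per-item: {per_item * 1000:.1f} ms, single: {single * 1000:.4f} ms")
```

The per-item loop is orders of magnitude slower even without any model in the picture, which matches the gap the question measured.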

cngwdvgl2#

The problem is that the DataLoader fetches the items one by one, calling __getitem__(self, index) 273468 times (your batch size).
My solution was to drop the DataLoader entirely.
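A minimal sketch of that idea: slice the array into batches by hand, so each batch is one vectorized operation instead of one call per sample. The helper name `iter_batches` and the dummy data are assumptions, not part of the question's code:

```python
import numpy as np

def iter_batches(array, batch_size):
    # One slice per batch instead of one __getitem__ call per sample.
    for start in range(0, array.shape[0], batch_size):
        yield array[start:start + batch_size]

data = np.zeros((10, 4), dtype=np.float32)
shapes = [batch.shape for batch in iter_batches(data, batch_size=4)]
print(shapes)  # [(4, 4), (4, 4), (2, 4)]
```

The trade-off is that you lose the DataLoader's shuffling, workers, and pinned-memory options, which is fine for a pure evaluation loop like the one in the question.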


8ehkhllq3#

You can fix this by implementing a __getitems__ function that takes a list of indices as its argument. In your case it returns the whole batch in a single call, which speeds things up considerably.
Also, there is no need to convert to a torch Tensor that early; that should be done by the DataLoader's collate function. I have implemented collate_fn as an example. In your particular case it might also work with the default implementation, but as soon as you want to return, e.g., inputs and labels, you need your own collate_fn when implementing __getitems__.

import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

data = np.load('slices.npy')
data = np.reshape(data, (-1, 1225))
print(data.shape)
# ==> (273468, 1225)

class UnlabeledTensorDataset(Dataset):
    def __init__(self, data_np_array):
        self.data_np_array = data_np_array
        self.samples = data_np_array.shape[0]

    def __getitem__(self, index):
        return self.data_np_array[index]
    
    def __getitems__(self, index_list):
        # Fancy indexing returns the whole batch in one call.
        return self.data_np_array[index_list]
    
    def __len__(self):
        return self.samples

def collate_fn(data):
    return torch.from_numpy(data)

test_set = UnlabeledTensorDataset(data)
test_loader = DataLoader(test_set, batch_size=data.shape[0], collate_fn=collate_fn)

start = datetime.datetime.now()
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to('cuda')
        print(batch.shape)     # ==> torch.Size([273468, 1225])
        y_pred = model(batch)
        loss = torch.sqrt(criterion(y_pred, batch))
        avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))



pw136qt24#

If you have the inputs and labels in separate numpy arrays and want to return them as a tuple, you would do it like this:

def collate_fn(data):
    inputs, labels = data
    return torch.from_numpy(inputs), torch.from_numpy(labels)
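Putting the pieces together, here is a hedged sketch of such a dataset; the class name `LabeledArrayDataset` and the dummy arrays are assumptions, and `__getitems__` support requires a recent PyTorch (2.x) DataLoader:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class LabeledArrayDataset(Dataset):
    """Hypothetical dataset holding separate input and label arrays."""
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __getitem__(self, index):
        return self.inputs[index], self.labels[index]

    def __getitems__(self, index_list):
        # One fancy-indexing slice per array returns the whole batch.
        return self.inputs[index_list], self.labels[index_list]

    def __len__(self):
        return self.inputs.shape[0]

def collate_fn(data):
    inputs, labels = data
    return torch.from_numpy(inputs), torch.from_numpy(labels)

inputs = np.zeros((8, 3), dtype=np.float32)
labels = np.arange(8, dtype=np.int64)
loader = DataLoader(LabeledArrayDataset(inputs, labels),
                    batch_size=8, collate_fn=collate_fn)
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([8, 3]) torch.Size([8])
```

When __getitems__ is defined, the DataLoader passes its return value straight to collate_fn, so the collate function above receives exactly the (inputs, labels) tuple.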
