我有一个问题与DataLoader形式Pytorch,因为它非常慢。
我做了一个测试来证明这一点,这里是代码:
data = np.load('slices.npy')
data = np.reshape(data, (-1, 1225))
data = torch.FloatTensor(data).to('cuda')
print(data.shape)
# ==> torch.Size([273468, 1225])
class UnlabeledTensorDataset(TensorDataset):
def __init__(self, data_tensor):
self.data_tensor = data_tensor
self.samples = data_tensor.shape[0]
def __getitem__(self, index):
return self.data_tensor[index]
def __len__(self):
return self.samples
test_set = UnlabeledTensorDataset(data)
test_loader = DataLoader(test_set, batch_size=data.shape[0])
start = datetime.datetime.now()
with torch.no_grad():
for batch in test_loader:
print(batch.shape) # ==> torch.Size([273468, 1225])
y_pred = model(batch)
loss = torch.sqrt(criterion(y_pred, batch))
avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))
# ==> 1527.57 (milliseconds) !!!!!!!!!!!!!!!!!!!!!!!!
start = datetime.datetime.now()
with torch.no_grad():
print(data.shape) # ==> torch.Size([273468, 1225])
y_pred = model(data)
loss = torch.sqrt(criterion(y_pred, data))
avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))
# ==> 2.0 (milliseconds) !!!!!!!!!!!!!!!!!!!!!!!!
我想使用数据加载器,但我想要一种方法来解决缓慢的问题,有人知道为什么会发生这种情况吗?
4条答案
按热度按时间mqxuamgl1#
时间差对我来说似乎是合乎逻辑的:
test_loader
并进行1225
推断。cngwdvgl2#
问题是DataLoader通过应用
__getitem__(self, index)
函数273468次(批量大小)来逐个获取项。我的解决方案是放弃数据记录器
8ehkhllq3#
你可以通过实现getitems函数来解决这个问题
该函数接受一个索引列表作为参数。在您的情况下,它将一次性返回整个批次,这将大大加快速度。
而且,没有必要这么早就把它转换成一个 Torch Tensor。这应该由DataLoader中的collate函数完成。我已经实现了collate_fn作为一个例子。在您的特定情况下,它也可能与默认实现一起工作,但只要您想返回例如。输入和标签,如果您正在实现getitems,则需要您自己的输入和标签
`
`
pw136qt24#
如果你在单独的numpy数组中有输入和标签,并将它们作为元组返回,你会这样做: