我尝试使用torch.utils.data.random_split
,如下所示:
import torch
from torch.utils.data import DataLoader, random_split
list_dataset = [1,2,3,4,5,6,7,8,9,10]
dataset = DataLoader(list_dataset, batch_size=1, shuffle=False)
random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(123))
但是,当我尝试这样做时,我得到了错误raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
我看了看文档,似乎我应该能够传递小数,总和为1,但显然它不工作。
我还在谷歌上搜索了这个错误,最接近的结果是this issue。
我做错了什么?
1条答案
按热度按时间zd287kbt1#
您可能使用的是旧版本的PyTorch,如Pytorch1.10,它不具有此功能。
要在旧版本中复制此功能,只需复制新版本的源代码即可: