PyTorch: how to use random_split with percentage splits (Sum of input lengths does not equal the length of the input dataset)

Posted by z9smfwbn on 2022-11-09 in Other

I tried to use torch.utils.data.random_split like this:

```python
import torch
from torch.utils.data import DataLoader, random_split
list_dataset = [1,2,3,4,5,6,7,8,9,10]
dataset = DataLoader(list_dataset, batch_size=1, shuffle=False)
random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(123))
```

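However, when I run this I get the error `raise ValueError("Sum of input lengths does not equal the length of the input dataset!")`.
I checked the documentation, and it seems I should be able to pass fractions that sum to 1, but apparently that doesn't work.
I also googled the error; the closest result I found was this issue.
What am I doing wrong?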

Answer #1, by zd287kbt

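You are probably using an older version of PyTorch, such as PyTorch 1.10, which does not have this feature: passing fractions to random_split was only added in PyTorch 1.13.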
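To tell which case you are in, you can compare the installed version against 1.13. The sketch below is a hypothetical helper (`supports_fractional_lengths` is not part of PyTorch); it assumes fractional lengths first shipped in 1.13 and only looks at the major and minor numbers, ignoring build suffixes such as `+cu113`.

```python
import re


def supports_fractional_lengths(version: str) -> bool:
    """Hypothetical helper: True if a PyTorch version string is 1.13 or newer.

    Assumption: random_split accepts fractions starting with PyTorch 1.13.
    Only "major.minor" is compared; suffixes like "+cu113" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 13)
```

In practice you would call it as `supports_fractional_lengths(torch.__version__)`; a result of `False` means you need one of the workarounds for older versions.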
To get the same behavior on an older version, you can simply copy the implementation from the newer release:

```python
import math
import warnings
from itertools import accumulate
from typing import List

from torch import default_generator, randperm
from torch.utils.data.dataset import Subset


def random_split(dataset, lengths,
                 generator=default_generator):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
    >>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator(
    ...   ).manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    # Fractions summing to 1: convert them to integer lengths first
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
    # accumulate(lengths) yields each split's end offset, like torch._utils._accumulate
    return [Subset(dataset, indices[offset - length : offset])
            for offset, length in zip(accumulate(lengths), lengths)]
```
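A lighter alternative, sketched here as an assumption rather than anything PyTorch provides, is to turn the fractions into integer lengths yourself, since older releases already accept integer lengths. The hypothetical helper `fractions_to_lengths` below follows the rule from the newer docstring: floor each share, then hand out leftover items one at a time, round-robin from the first split.

```python
import math
from typing import List, Sequence


def fractions_to_lengths(n_items: int, fractions: Sequence[float]) -> List[int]:
    """Hypothetical helper: convert split fractions to integer lengths.

    Each length starts as floor(n_items * frac); any leftover items are then
    added one at a time, round-robin, so the lengths sum exactly to n_items.
    """
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("Fractions must sum to 1")
    lengths: List[int] = []
    for i, frac in enumerate(fractions):
        if frac < 0 or frac > 1:
            raise ValueError(f"Fraction at index {i} is not between 0 and 1")
        lengths.append(math.floor(n_items * frac))
    # Distribute the items lost to flooring, starting from index 0
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths
```

You could then call the built-in function on an older release with `random_split(list_dataset, fractions_to_lengths(len(list_dataset), [0.8, 0.1, 0.1]), generator=torch.Generator().manual_seed(123))`. Note that this splits the list itself rather than the DataLoader: random_split expects an indexable dataset, and each resulting subset can be wrapped in its own DataLoader afterwards.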