pytorch torch.utils.data.random_split()未拆分数据集

klr1opcd  于 2024-01-09  发布在  其他
关注(0)|答案(1)|浏览(202)

我使用ImageFolder从一个目录加载数据:

full_dataset = ImageFolder('some_dir', transform=transform)

字符串
当我打印它的长度时,它给出:32854。现在我想使用torch.utils.data.random_split()ImageFolder返回的Dataset拆分为训练和测试数据集。我尝试传递分数[0.8, 0.2],长度为[len(full_dataset) - 100, 100]

train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [len(full_dataset) - 100, 100])


但是当我使用len(train_dataset.dataset.imgs)len(test_dataset.dataset.imgs)打印它们的长度时,它们显示的值与full_dataset相同。
为什么我的分身不行?

33qvvth1

33qvvth11#

在执行train_dataset.dataset(和test_dataset.dataset类似)时,您引用的是原始数据集(在本例中为full_dataset)。因此,train_dataset.dataset(和test_dataset.dataset)上的imgs属性将给予属于原始数据集的所有图像,而不是每个分割的图像。
由于random_split返回的Subset对象具有__len__方法(Subset在技术上是抽象类Dataset的子类),您可以直接在它们上使用len来获取每个拆分/子集的长度:

len(train_dataset)
len(test_dataset)

字符串

相关问题