我使用ImageFolder
从一个目录加载数据:
full_dataset = ImageFolder('some_dir', transform=transform)
字符串
当我打印它的长度时,它给出:32854。现在我想使用torch.utils.data.random_split()
将ImageFolder
返回的Dataset
拆分为训练和测试数据集。我尝试传递分数[0.8, 0.2]
,长度为[len(full_dataset) - 100, 100]
。
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [len(full_dataset) - 100, 100])
型
但是当我使用len(train_dataset.dataset.imgs)
和len(test_dataset.dataset.imgs)
打印它们的长度时,它们显示的值与full_dataset
相同。
为什么我的分身不行?
1条答案
按热度按时间33qvvth11#
在执行
train_dataset.dataset
(和test_dataset.dataset
类似)时,您引用的是原始数据集(在本例中为full_dataset
)。因此,train_dataset.dataset
(和test_dataset.dataset
)上的imgs
属性将给予属于原始数据集的所有图像,而不是每个分割的图像。由于
random_split
返回的Subset
对象具有__len__
方法(Subset
在技术上是抽象类Dataset
的子类),您可以直接在它们上使用len
来获取每个拆分/子集的长度:字符串