使用数据加载器pytorch时标签错误

tkqqtvp1  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(129)

下面的代码:

def get_relevant_indicies(dataset):
    """Returns the indicies of the classes in the dataset"""
    indicies = []
    for i in range(len(dataset)):
        idx = dataset[i][1]
        indicies.append(idx)
    return indicies

def get_data(batch_size, folder):
    """Takes a batch_size and the name of the folder (name of folder most likely called dataset)
    Example:
    get_data(1, "~/aps360-proj/dataset")

    """
    classes = ("testing1", "testing2", "testing3")

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    #Load images
    trainset = torchvision.datasets.ImageFolder(folder, transform=transform)    
    #Get indicies of images
    relevant_train_indicies = get_relevant_indicies(trainset)

    np.random.seed(1)
    np.random.shuffle(relevant_train_indicies)
    random_sampler = SubsetRandomSampler(relevant_train_indicies)
    for i in random_sampler:
        print(i)
    train_loader = torch.utils.data.DataLoader(trainset, sampler=random_sampler)
    for images, labels in train_loader:
        print(labels)

两个print语句的输出不同,我不知道为什么。对于random_sampler,它输出0-〉2,这是预期的,因为有三个文件夹,但在将其传递到dataloader之后,它只输出0-〉1
第一次

qij5mzcb

qij5mzcb1#

很难说你想用get_relevant_indicies函数实现什么。
get_relevant_indicies函数返回数据集中每个样本的 * 标签 * 列表。这就是dataset[i][1]为ImageFolder数据集返回的内容-图像的目标标签(图像本身将在dataset[i][0]中)。
然后,您使用这些标签(0、1或2)作为数据索引,这完全不是您想要的。

相关问题