使用PyTorch在每次迭代中仅对一个类的批次进行高效采样

g9icjywg  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(176)

bounty已结束。回答此问题可获得+50的声望奖励。奖励宽限期将在22小时后结束。Thoth正在寻找标准答案:我正在为我的问题寻找一个高效的解决方案,并希望附带适合PyTorch初学者的讲解。

我想在ImageNet数据集(1000个类)上训练一个分类器,并且我需要每个批次包含来自同一个类的64张图像,而相邻的连续批次来自不同的类。

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os

class DS(Dataset):
    """Dataset wrapper that groups the sample indices of ``data`` by class.

    Args:
        data: an indexable dataset whose items are ``(sample, class_label)``
            pairs. If it exposes a ``targets`` sequence (one label per
            sample, as torchvision's ``ImageFolder`` does), the labels are
            read from there so that no sample has to be loaded.
        num_classes: number of distinct class labels; labels are expected
            to lie in ``range(num_classes)``.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data

        # create a list of lists, where every sublist contains the indices of
        # the samples that belong to the corresponding class label
        self.indices = [[] for _ in range(num_classes)]

        labels = getattr(data, 'targets', None)
        if labels is not None:
            # Fast path: only the labels are touched, no image is decoded
            # or transformed.
            for i, class_label in enumerate(labels):
                self.indices[int(class_label)].append(i)
        else:
            # Fallback: iterate over the full samples (slow for image data,
            # because every item is loaded just to read its label).
            for i, (_, class_label) in enumerate(data):
                self.indices[int(class_label)].append(i)

    def classes(self):
        """Return the per-class lists of sample indices."""
        return self.indices

    def __getitem__(self, index):
        """Return the ``(sample, class_label)`` pair stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

class BatchSampler:
    """Batch sampler whose batches each contain samples of a single class.

    Args:
        classes: list of lists; sublist ``c`` holds the sample indices that
            belong to class ``c``.
        batch_size: number of indices per batch; must be smaller than the
            size of the smallest class.

    Raises:
        AssertionError: if ``batch_size`` is not smaller than the smallest
            class size.
    """

    def __init__(self, classes, batch_size):
        # classes is a list of lists where each sublist refers to a class and
        # contains the sample ids that belong to this class
        self.classes = classes
        class_sizes = [len(x) for x in classes]
        self.n_batches = sum(class_sizes) // batch_size
        self.min_class_size = min(class_sizes)
        self.batch_size = batch_size
        # Class visiting order, shuffled once at construction time.
        self.class_range = list(range(len(self.classes)))
        random.shuffle(self.class_range)

        assert batch_size < self.min_class_size, \
            'batch_size should be smaller than the smallest class size ({})'.format(self.min_class_size)

    def __iter__(self):
        """Return an iterator over ``n_batches`` lists of sample indices."""
        batches = []
        for j in range(self.n_batches):
            if j < len(self.class_range):
                # First pass: visit every class exactly once.
                batch_class = self.class_range[j]
            else:
                # Remaining batches: pick the class at random.
                batch_class = random.choice(self.class_range)
            # replace=False so that a batch never contains the same sample twice
            batch = np.random.choice(self.classes[batch_class], self.batch_size, replace=False)
            batches.append(batch.tolist())
        return iter(batches)

    def __len__(self):
        """Return the number of batches per epoch (needed by ``len(DataLoader)``)."""
        return self.n_batches

def main():
    """Iterate once over the class-wise batches and report the labels seen.

    Reads the module-level ``train_dataset`` and ``args`` created in the
    ``__main__`` section.
    """
    # FakeData carries a ``num_classes`` attribute; ImageFolder does not and
    # exposes the list of class names as ``classes`` instead.
    num_classes = getattr(train_dataset, 'num_classes', None)
    if num_classes is None:
        num_classes = len(train_dataset.classes)

    _train_dataset = DS(train_dataset, num_classes)
    _batch_sampler = BatchSampler(_train_dataset.classes(), batch_size=args.batch_size)
    _train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler)

    labels = set()
    for inputs, _labels in _train_loader:
        # .item() only succeeds for a one-element tensor, so this also
        # verifies that the whole batch carries a single label.
        batch_label = torch.unique(_labels).item()
        labels.add(batch_label)
        print("Unique labels: {}".format(batch_label))

    print('Length of traversed unique labels: {}'.format(len(labels)))

if __name__ == "__main__":
    # Script entry point: parse the command line, build the train/val
    # datasets (fake or ImageFolder based) and run main().

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # The help text now states the default that is actually configured (64).
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        # Randomly generated images: no dataset on disk is required.
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        # Expects <data>/train and <data>/val directories.
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training pipeline: random crop + horizontal flip augmentation.
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # Validation pipeline: deterministic resize + center crop.
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Samplers are initialized to None and train_sampler will be replaced
    train_sampler, val_sampler = None, None
    # NOTE: these two loaders are not used by main(), which builds its own
    # loader around the class-wise batch sampler.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()

它打印:Length of traversed unique labels: 100 .
但是,在for循环中创建self.indices需要花费大量时间,有没有更有效的方法来构造这个采样器?
EDIT:使用yield(生成器)的实现

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
from tqdm import tqdm
import os.path

class DS(Dataset):
    """Dataset wrapper that groups the sample indices of ``data`` by class.

    Args:
        data: a sized, indexable dataset whose items are
            ``(sample, class_label)`` pairs. If it exposes a ``targets``
            sequence (one label per sample, as torchvision's ``ImageFolder``
            does), the labels are read from there so that no sample has to
            be loaded.
        num_classes: number of distinct class labels; labels are expected
            to lie in ``range(num_classes)``.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.data_len = len(data)

        indices = [[] for _ in range(num_classes)]

        labels = getattr(data, 'targets', None)
        if labels is not None:
            # Fast path: only the labels are touched, no image is decoded
            # or transformed.
            for i, class_label in enumerate(labels):
                indices[int(class_label)].append(i)
        else:
            # Fallback: iterate over the full samples (slow for image data),
            # showing a progress bar while doing so.
            for i, (_, class_label) in tqdm(enumerate(data), total=self.data_len, miniters=1,
                                            desc='Building class indices dataset..'):
                indices[int(class_label)].append(i)

        self.indices = indices

    def per_class_sample_indices(self):
        """Return the per-class lists of sample indices."""
        return self.indices

    def __getitem__(self, index):
        """Return the ``(sample, class_label)`` pair stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return self.data_len

class BatchSampler:
    """Batch sampler whose batches each contain samples of a single class.

    Args:
        per_class_sample_indices: list of lists; sublist ``c`` holds the
            sample indices that belong to class ``c``.
        batch_size: maximum number of indices per batch. A class with fewer
            samples than ``batch_size`` yields all of its samples instead.

    Attributes:
        n_batches: number of batches produced per iteration (an ``int``).
            The original ``n_batches()`` method was shadowed by this
            attribute and could never be called on an instance, so it is
            replaced by ``__len__``.
    """

    def __init__(self, per_class_sample_indices, batch_size):
        # per_class_sample_indices is a list of lists where each sublist
        # refers to a class and contains the sample ids that belong to it
        self.per_class_sample_indices = per_class_sample_indices
        class_sizes = [len(x) for x in per_class_sample_indices]
        self.n_batches = sum(class_sizes) // batch_size
        self.min_class_size = min(class_sizes)
        self.batch_size = batch_size
        # Class visiting order, shuffled once at construction time.
        self.class_range = list(range(len(self.per_class_sample_indices)))
        random.shuffle(self.class_range)

    def __iter__(self):
        """Yield ``n_batches`` lists of sample indices, one class per batch."""
        for j in range(self.n_batches):
            if j < len(self.class_range):
                # First pass: visit every class exactly once.
                batch_class = self.class_range[j]
            else:
                # Remaining batches: pick the class at random.
                batch_class = random.choice(self.class_range)
            pool = self.per_class_sample_indices[batch_class]
            if self.batch_size <= len(pool):
                # replace=False so that a batch never contains the same
                # sample twice; tolist() keeps the yielded type uniform.
                batch = np.random.choice(pool, self.batch_size, replace=False).tolist()
            else:
                # Class is smaller than a batch: use all of it (as a copy,
                # so consumers cannot mutate the stored index list).
                batch = list(pool)
            yield batch

    def __len__(self):
        """Return the number of batches per epoch (needed by ``len(DataLoader)``)."""
        return self.n_batches

def main():
    """Build (or load) the per-class index cache and iterate one epoch.

    Reads the module-level ``train_dataset``, ``num_classes`` and ``args``
    created in the ``__main__`` section.
    """
    file_path = 'a_file_path'
    file_name = 'per_class_sample_indices.pt'
    cache_file = os.path.join(file_path, file_name)

    # NOTE(review): the cache file name does not depend on the dataset, so a
    # cache written for one dataset (e.g. --dummy) is reused for another;
    # delete the file when switching datasets.
    if not os.path.exists(cache_file):
        print('File: {} does not exists. Create it.'.format(file_name))
        per_class_sample_indices = DS(train_dataset, num_classes).per_class_sample_indices()
        # torch.save does not create missing parent directories.
        os.makedirs(file_path, exist_ok=True)
        torch.save(per_class_sample_indices, cache_file)
    else:
        per_class_sample_indices = torch.load(cache_file)
        print('File: {} exists. Do not create it.'.format(file_name))

    batch_sampler = BatchSampler(per_class_sample_indices,
                                 batch_size=args.batch_size)

    # batch_sampler is mutually exclusive with batch_size, shuffle and
    # sampler, so none of those are passed here.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.workers,
        pin_memory=True,
        batch_sampler=batch_sampler
    )

    # No validation loader is built: this script only checks the sampler.

    labels = set()
    for inputs, _labels in train_loader:
        # .item() only succeeds for a one-element tensor, so this also
        # verifies that the whole batch carries a single label.
        batch_label = torch.unique(_labels).item()
        labels.add(batch_label)
        print("Unique labels: {}".format(batch_label))

    print('Length of traversed unique labels: {}'.format(len(labels)))

if __name__ == "__main__":
    # Script entry point: parse the command line, build the train/val
    # datasets (fake or ImageFolder based), set num_classes and run main().

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # The help text now states the default that is actually configured (64).
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        # Randomly generated images: no dataset on disk is required.
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        # Expects <data>/train and <data>/val directories.
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training pipeline: random crop + horizontal flip augmentation.
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # Validation pipeline: deterministic resize + center crop.
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        # One class per sub-directory found by ImageFolder.
        num_classes = len(train_dataset.classes)

    main()

类似的post,但在TensorFlow中,可在此处找到

ewm0tg9j

ewm0tg9j1#

你的代码看起来很好。这里的问题不是采样器,而是你需要执行的预处理步骤,以便按类对示例索引进行排序。由于这总是相同的排序,我建议您缓存此信息(self.indices中包含的数据),这样就不必在每次加载数据集时重新构建它。您可以使用numpy.save或torch.save来执行此操作。

carvr3hs

carvr3hs2#

您应该为DataLoader编写自己的batch_sampler类。

相关问题