修改PyTorch DataLoader以避免在批处理中混合来自不同目录的文件

z0qdvdin  于 2024-01-09  发布在  其他
关注(0)|答案(1)|浏览(189)

我想将固定长度的图像序列加载到相同大小的批次中(例如序列长度=批次大小= 7)。
有多个目录,每个目录都有来自不同数量图像序列的图像。来自不同目录的序列彼此不相关。
用我现在的代码,我可以处理几个子目录,但是如果一个目录中没有足够的图像来填充一批,剩下的图像将从下一个目录中取出。我想避免这种情况。
相反,如果当前目录中没有足够的图像,则应丢弃一个批次,而该批次应仅使用下一个目录中的图像填充。这样,我希望避免在同一批次中混合不相关的图像序列。如果目录中没有足够的图像来创建一个批次,则应完全跳过该批次。
例如,序列长度/批量大小为7:

  • 目录A有15个图像→创建2个批次,每个批次有7个图像;忽略其余图像
  • 目录B有10个图像→创建1个批次,其中有7个图像;忽略其余图像
  • 目录C有3个图像→目录被完全跳过

我还在学习中,但我认为这可以用一个costum批量采样器来完成?不幸的是,我对此有些问题。也许有人可以帮我找到解决方案。
这是我现在的代码:

class MainDataset(Dataset):

    def __init__(self, img_dir, use_folder_name=False):
        self.gt_images = self._load_main_dataset(img_dir)
        self.dataset_len = len(self.gt_images)
        self.use_folder_name = use_folder_name

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        img_dir = self.gt_images[idx]
        img_name = self._get_name(img_dir)

        gt = self._load_img(img_dir)

        # Skip non-image files
        if gt is None:
            return None

        gt = torch.from_numpy(gt).permute(2, 0, 1)

        return gt, img_name

    def _get_name(self, img_dir):
        if self.use_folder_name:
            return img_dir.split(os.sep)[-2]
        else:
            return img_dir.split(os.sep)[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):
        if not (os.path.isdir(img_dir)):
            return [img_dir]

        gt_images = []
        for root, dirs, files in os.walk(img_dir):
            for file in files:
                if not is_valid_file(file):
                    continue
                gt_images.append(os.path.join(root, file))

        gt_images.sort()

        return gt_images

    def _load_img(self, img_path):

        gt_image = io.imread(img_path)
        gt_image_bd = getBitDepth(gt_image)
        gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)

        return gt_image

def is_valid_file(file_name: str):

    # Check if the file has a valid image extension
    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif']

    for ext in valid_image_extensions: 
        if file_name.lower().endswith(ext):
            return True

    return False


sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)

字符串

1yjd4xko

1yjd4xko1#

虽然使用批处理采样器可能是一个好主意,可以有一个通用的自定义数据集,你可以不同的采样,我更喜欢一个简单的方法。
我会在init函数中构造一个数据结构,它已经包含了你要操作的所有图像序列。事实是,目前,你的Dataset类在撒谎,因为它说你的数据集的长度等于图像文件夹的数量。这是不正确的,因为它取决于文件夹中包含的图像数量。
目前,您的数据集一次只返回一个图像,而您需要序列。
你的问题中也缺少了一些关于数据集实际结构的信息。尽管如此,这里有一个Datatet类的建议:

class MainDataset(Dataset):

    def __init__(self, img_dir, use_folder_name=False, seq_len=7):
        self.seq_len = seq_len
        self.gt_images = self._load_main_dataset(img_dir)
        self.use_folder_name = use_folder_name

    def __len__(self):
        return len(self.gt_images)

    def __getitem__(self, idx):
        label, sequence = self.gt_images[idx]

        image_sequence = []
        for image_path in sequence:
            loaded_image = self._load_img(image_path)
            loaded_image = torch.from_numpy(loaded_image).permute(2, 0, 1)

            image_sequence.append(loaded_image)

        all_sequence = torch.stack(image_sequence, dim=0)

        # return a tensort of the sequence of images and the label 
        return all_sequence, label

    def _get_name(self, img_dir):
        if self.use_folder_name:
            return img_dir.split(os.sep)[-2]
        else:
            return img_dir.split(os.sep)[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):

        # I don't really know why you don't throw an exception here.
        if not (os.path.isdir(img_dir)):
            return [img_dir]

        gt_images = []

        # Why using walk ? What is the structure of the dataset ?
        for root, dirs, files in os.walk(img_dir):

            # This variable accumulates the images in the sequence
            image_sequence = []

            for file in files:
                if not is_valid_file(file):
                    continue

                img_path = os.path.join(root, file)
                image_sequence.append(img_path)

                if len(image_sequence) == self.seq_len:
                    sorted_sequence = image_sequence.sort()
                    label = self._get_name(sorted_sequence)

                    gt_images.append((label,sorted_sequence))
                    image_sequence = []

        # Now gt_images is a list of tuples (label, sequence)
        return gt_images

    def _load_img(self, img_path):

        gt_image = io.imread(img_path)
        gt_image_bd = getBitDepth(gt_image)
        gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)

        return gt_image

def is_valid_file(file_name: str):

    # Check if the file has a valid image extension
    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif']

    for ext in valid_image_extensions: 
        if file_name.lower().endswith(ext):
            return True

    return False


sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)

字符串

相关问题