pytorch: merging three datasets (predefined and custom)

mwngjboj · posted on 2023-03-02

I am training an AI model to recognize handwritten Korean characters as well as English characters and digits. That means I need three datasets: a custom Korean-character dataset plus two predefined ones.
I have all three datasets and I am merging them now, but when I print the train_set paths they only show MJSynth, which is wrong.

긴장_1227682.jpg is in my custom Korean dataset, not in MJSynth.
**Code**
# Excerpt from references/recognition/train_pytorch.py (see the traceback);
# `parts` (a list of dataset folders) and `args` are defined earlier in the script.
custom_train_set = RecognitionDataset(
    parts[0].joinpath("images"),
    parts[0].joinpath("labels.json"),
    img_transforms=Compose(
        [
            T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
            # Augmentations
            T.RandomApply(T.ColorInversion(), 0.1),
            ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
        ]
    ),
)
# Merge any additional custom dataset folders into the first one.
if len(parts) > 1:
    for subfolder in parts[1:]:
        custom_train_set.merge_dataset(
            RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
        )

train_set = MJSynth(
    train=True,
    img_folder='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px',
    label_path='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt',
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)

_train_set = SynthText(
    train=True,
    recognition_task=True,
    download=True,  # NOTE: download can take really long depending on your bandwidth
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
train_set.data.extend([(np_img, target) for np_img, target in _train_set.data])
train_set.data.extend([(np_img, target) for np_img, target in custom_train_set.data])
**Traceback**
Traceback (most recent call last):
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 485, in <module>
    main(args)
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 396, in main
    fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 118, in fit_one_epoch
    for images, targets in progress_bar(train_loader, parent=mb):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/fastprogress/fastprogress.py", line 50, in __iter__
    raise e
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/fastprogress/fastprogress.py", line 41, in __iter__
    for i,o in enumerate(self.gen):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
    return self._process_data(data)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
    data.reraise()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/media/cvpr/CM_22/doctr/doctr/datasets/datasets/base.py", line 48, in __getitem__
    img, target = self._read_sample(index)
  File "/media/cvpr/CM_22/doctr/doctr/datasets/datasets/pytorch.py", line 37, in _read_sample
    else read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32)
  File "/media/cvpr/CM_22/doctr/doctr/io/image/pytorch.py", line 52, in read_img_as_tensor
    pil_img = Image.open(img_path, mode="r").convert("RGB")
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/PIL/Image.py", line 2912, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/긴장_1227682.jpg'

c86crjj0 #1

You have to change how the three datasets are arranged. Because you are using the docTR library, merging datasets there works differently than with a plain ConcatDataset in PyTorch.
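The traceback shows the mechanism behind the wrong paths: docTR's file-backed datasets store bare image file names in their .data list and join them against the dataset's own root only when a sample is read (the os.path.join(self.root, img_name) call in _read_sample). Extending the MJSynth set's .data with entries from the custom Korean set therefore makes the loader look for 긴장_1227682.jpg under the MJSynth folder. A minimal sketch of that failure mode, reusing the paths from the question (the label value is hypothetical):

import os

# MJSynth root taken from the question.
mjsynth_root = '/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px'

# A (file_name, label) entry as the custom Korean dataset stores it;
# the label here is made up for illustration.
entry = ('긴장_1227682.jpg', '긴장')

# After train_set.data.extend(...), lookups use the MJSynth root, which
# reproduces the exact path from the FileNotFoundError:
print(os.path.join(mjsynth_root, entry[0]))
# -> /media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/긴장_1227682.jpg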
Print the size of each individual dataset so you can check its actual length, and then the overall size of the merged set.

mjsynth_train = MJSynth(
    train=True,
    img_folder='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px',
    label_path='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt',
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)

print("MJSynth dataset size is", len(mjsynth_train))

synth_train = SynthText(
    recognition_task=True,
    download=True,  # NOTE: download can take really long depending on your bandwidth
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)

print("SynthText dataset size is", len(synth_train))

# `train_set` here is the custom Korean RecognitionDataset built in the
# question's code, used as the base that the other two are merged into.
train_set.data.extend([(np_img, target) for np_img, target in mjsynth_train.data])
train_set.data.extend([(np_img, target) for np_img, target in synth_train.data])

print("Overall dataset size is", len(train_set))
