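I am training an AI model to recognize handwritten Korean characters along with English characters and digits. That means I need three datasets: a custom Korean character dataset and two others.
I have all three and am now merging them, but when I print the `train_set` path it only shows MJSynth, which is wrong.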
`긴장_1227682.jpg` is in my custom Korean dataset, not in MJSynth.
**Code**
custom_train_set = RecognitionDataset(
    parts[0].joinpath("images"),
    parts[0].joinpath("labels.json"),
    img_transforms=Compose(
        [
            T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
            # Augmentations
            T.RandomApply(T.ColorInversion(), 0.1),
            ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
        ]
    ),
)
if len(parts) > 1:
    for subfolder in parts[1:]:
        custom_train_set.merge_dataset(
            RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
        )
train_set = MJSynth(
    train=True,
    img_folder='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px',
    label_path='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt',
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
_train_set = SynthText(
    train=True,
    recognition_task=True,
    download=True,  # NOTE: download can take really long depending on your bandwidth
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
train_set.data.extend([(np_img, target) for np_img, target in _train_set.data])
train_set.data.extend([(np_img, target) for np_img, target in custom_train_set.data])
**Traceback**
Traceback (most recent call last):
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 485, in <module>
    main(args)
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 396, in main
    fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 118, in fit_one_epoch
    for images, targets in progress_bar(train_loader, parent=mb):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/fastprogress/fastprogress.py", line 50, in __iter__
    raise e
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/fastprogress/fastprogress.py", line 41, in __iter__
    for i,o in enumerate(self.gen):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
    return self._process_data(data)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
    data.reraise()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/media/cvpr/CM_22/doctr/doctr/datasets/datasets/base.py", line 48, in __getitem__
    img, target = self._read_sample(index)
  File "/media/cvpr/CM_22/doctr/doctr/datasets/datasets/pytorch.py", line 37, in _read_sample
    else read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32)
  File "/media/cvpr/CM_22/doctr/doctr/io/image/pytorch.py", line 52, in read_img_as_tensor
    pil_img = Image.open(img_path, mode="r").convert("RGB")
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/PIL/Image.py", line 2912, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/긴장_1227682.jpg'
1 Answer
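You need to change how the three datasets are arranged. Because you are using the docTR library, merging datasets works differently from the usual `ConcatDataset` in PyTorch. Print the size of each individual dataset so you can check each one's actual length against the overall size of the merged dataset.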
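The traceback shows the image path being built with `os.path.join(self.root, img_name)`, so every sample in the merged `train_set` is resolved against the MJSynth folder, including the Korean file names copied in with `data.extend`. The sketch below reproduces that behaviour with a hypothetical stand-in class (`FolderDataset`, `path_of`, and `merge_with_absolute_paths` are illustrative names, not docTR API, and the `/data/...` folders are made up). It also shows one possible workaround, which the answer above does not confirm: store paths that already include each dataset's own folder, since joining a root with an absolute path keeps only the absolute path.

```python
import posixpath  # POSIX flavour of os.path, so the demo behaves the same on any OS


class FolderDataset:
    """Hypothetical stand-in for a dataset that stores (file name, label)
    pairs and resolves every name against a single root folder."""

    def __init__(self, root, samples):
        self.root = root
        self.data = list(samples)  # copy, so merges do not touch the source list

    def path_of(self, index):
        # Same pattern as os.path.join(self.root, img_name) in the traceback
        img_name, _ = self.data[index]
        return posixpath.join(self.root, img_name)

    def __len__(self):
        return len(self.data)


def merge_with_absolute_paths(target, source):
    """Copy samples from source into target, prefixing each name with source.root."""
    target.data.extend(
        (posixpath.join(source.root, name), label) for name, label in source.data
    )


# Made-up folders and samples, for illustration only
mjsynth = FolderDataset("/data/mjsynth", [("1_word.jpg", "word")])
custom = FolderDataset("/data/korean/images", [("긴장_1227682.jpg", "긴장")])

# Naive merge, as in the question: file names are copied, the root is not
naive = FolderDataset(mjsynth.root, mjsynth.data)
naive.data.extend(custom.data)
naive_path = naive.path_of(1)

# Merge that keeps each sample's own folder
fixed = FolderDataset(mjsynth.root, mjsynth.data)
merge_with_absolute_paths(fixed, custom)
fixed_path = fixed.path_of(1)

print(naive_path)  # Korean file looked up under the MJSynth root
print(fixed_path)  # Korean file looked up in its own folder
print(len(mjsynth), len(custom), len(fixed))  # individual sizes vs merged size
```

Printing the individual and merged lengths, as the answer suggests, confirms that no samples were lost in the merge; printing a resolved path from each source confirms that they point to the right folder.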