如何使用自定义coco风格的数据集重新训练torchvision的keypoint r-cnn?

t1rydlwq  于 2021-08-20  发布在  Java
关注(0)|答案(0)|浏览(523)

我已经使用coco注解器创建了一个自定义coco关键点样式的数据集,并希望对torchvision的关键点r-cnn进行再培训。我试图使用torchvision的cocodetection数据集类来加载数据,我不得不重写 _load_image 方法,因为我的数据集有子目录。然后,我尝试将数据集 Package 到数据加载器中,并出现以下错误:

>>> dl = DataLoader(coco, batch_size=4)
>>> feat, lbl = next(iter(dl))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 82, in default_collate
    raise RuntimeError('each element in list of batch should be of equal size')
RuntimeError: each element in list of batch should be of equal size

鉴于keypoint r-cnn需要一个[通道、高度、宽度]Tensor列表,尝试将数据集放入数据加载器是正确的做法吗?
此外,当我以可接受的格式获得数据时,我很难弄清楚我应该如何实际训练模型。我看过https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-作为固定特征提取器和https://github.com/pytorch/vision/tree/master/references/detection 我还是有点困惑。我可以得到一些关于如何在带有单个gpu的机器上训练模型的指导吗?

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题