我已经使用coco注解器创建了一个自定义coco关键点样式的数据集,并希望对torchvision的关键点r-cnn进行再培训。我试图使用torchvision的cocodetection数据集类来加载数据,我不得不重写 _load_image
方法,因为我的数据集有子目录。然后,我尝试将数据集 Package 到数据加载器中,并出现以下错误:
>>> dl = DataLoader(coco, batch_size=4)
>>> feat, lbl = next(iter(dl))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 82, in default_collate
raise RuntimeError('each element in list of batch should be of equal size')
RuntimeError: each element in list of batch should be of equal size
鉴于keypoint r-cnn需要一个[通道、高度、宽度]Tensor列表,尝试将数据集放入数据加载器是正确的做法吗?
此外,当我以可接受的格式获得数据时,我很难弄清楚我应该如何实际训练模型。我看过https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-作为固定特征提取器和https://github.com/pytorch/vision/tree/master/references/detection 我还是有点困惑。我可以得到一些关于如何在带有单个gpu的机器上训练模型的指导吗?
暂无答案!
目前还没有任何答案,快来回答吧!