pytorch创建自己的Dataset加载数据集

x33g5p2x  于2022-07-10 转载在 其他  
字(1.8k)|赞(0)|评价(0)|浏览(608)

创建一个类并继承 torch.utils.data.dataset.Dataset 类

  class MyDataset(Dataset):
      """Map-style dataset built from a list of image names.

      Args:
          data_path: Root directory of the dataset.
          img_size: Image size.
          train_lines: List of image names, one entry per sample.
      """

      def __init__(self, data_path, img_size, train_lines):
          super().__init__()
          self.data_path = data_path
          self.img_size = img_size
          # Skip blank entries (e.g. an empty line in train.txt) so they are
          # not counted as samples.
          self.train_lines = [line for line in train_lines if line.strip()]
          self.length = len(self.train_lines)

      def __len__(self):
          # DataLoader's default samplers call len() on the dataset.
          return self.length

创建 __getitem__ 方法（按索引返回一个样本；注意 DataLoader 的采样器还要求数据集实现 __len__ 方法，返回样本数量）

  class MyDataset(Dataset):
      """Map-style dataset of (image, label) pairs.

      For each name, the image is read from ``<data_path>/dem/<name>.tif`` and
      the label from ``<data_path>/label/<name>.png``. Both are returned as
      numpy arrays resized to ``img_size x img_size``.

      Requires (not shown in this snippet): ``import os``,
      ``import numpy as np``, ``import cv2``, ``from PIL import Image`` and
      ``from torch.utils.data import Dataset``.

      Args:
          data_path: Root directory of the dataset.
          img_size: Side length (pixels) both image and label are resized to.
          train_lines: Annotation lines. The first whitespace-separated token
              of each line is the image name. Blank lines are ignored.
      """

      def __init__(self, data_path, img_size, train_lines):
          super().__init__()
          self.data_path = data_path
          self.img_size = img_size
          # Blank lines carry no image name and would make split()[0] raise
          # IndexError in __getitem__.
          self.train_lines = [line for line in train_lines if line.strip()]
          self.length = len(self.train_lines)

      def __len__(self):
          # Required by DataLoader: its default samplers call len(dataset).
          return self.length

      def __getitem__(self, index):
          """Return the resized (image, label) numpy arrays for sample *index*."""
          name = self.train_lines[index].split()[0]  # image name
          image_path = os.path.join(self.data_path, "dem", name + ".tif")
          label_path = os.path.join(self.data_path, "label", name + ".png")
          # Context managers close the underlying files; np.array() forces
          # the pixel data to load before the file is closed.
          with Image.open(image_path) as img:
              image = np.array(img)
          with Image.open(label_path) as lbl:
              label = np.array(lbl)
          size = (self.img_size, self.img_size)
          image = cv2.resize(image, size)
          # Labels are presumably discrete class ids (segmentation masks):
          # nearest-neighbour keeps every pixel an existing id, whereas the
          # default bilinear interpolation would blend neighbouring ids.
          label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
          # NOTE(review): if the model needs an explicit channel axis for the
          # image, add it here, e.g. image = image[np.newaxis, :].
          return image, label

加载数据集

如果不知道如何将文件夹中所有图片的名称写入 TXT 文件，可以参考：python读取文件夹中的所有图片并将图片名逐行写入txt中：https://blog.csdn.net/weixin_43598687/article/details/125666776?spm=1001.2014.3001.5501

  dataset_path = r"E:/workspace/PyCharmProject/dem_feature/dem/512"
  # Read the image-name lists line by line. Blank lines are skipped because
  # they hold no image name and would crash MyDataset.__getitem__
  # (split()[0] on an empty line raises IndexError).
  with open(os.path.join(dataset_path, "dem", "train.txt"), "r") as f:
      train_lines = [line for line in f if line.strip()]
  with open(os.path.join(dataset_path, "dem", "val.txt"), "r") as f:
      val_lines = [line for line in f if line.strip()]
  train_dataset = MyDataset(dataset_path, img_size=512, train_lines=train_lines)
  train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=False)
  for iteration, (imgs, labels) in enumerate(train_dataloader):
      print(imgs, labels)

相关文章