class MyDataset(Dataset):
'''
data_path: 数据集路径
img_size: 图片大小
train_lines: 图片名数组
'''
def __init__(self,data_path,img_size,train_lines):
super(MyDataset, self).__init__()
self.data_path = data_path
self.img_size = img_size
self.train_lines = train_lines
self.length = len(train_lines)
class MyDataset(Dataset):
'''
data_path: 数据集路径
img_size: 图片大小
train_lines: 图片名数组
'''
def __init__(self,data_path,img_size,train_lines):
super(MyDataset, self).__init__()
self.data_path = data_path
self.img_size = img_size
self.train_lines = train_lines
def __getitem__(self, index):
annotation_line = self.train_lines[index]
name = annotation_line.split()[0] # 获取图片名
image = Image.open(os.path.join(os.path.join(self.data_path,"dem"),name+".tif"))
label = Image.open(os.path.join(os.path.join(self.data_path, "label"), name + ".png"))
image = np.array(image)
label = np.array(label)
image = cv2.resize(image,(self.img_size,self.img_size))
label = cv2.resize(label,(self.img_size,self.img_size))
# image = image[np.newaxis,:]
print("images size: {}, label size: {}".format(image.shape,label.shape))
return image,label
如果不知道如何将文件夹中所有图片名称写入TXT中可以参考:python读取文件夹中的所有图片并将图片名逐行写入txt中:https://blog.csdn.net/weixin_43598687/article/details/125666776?spm=1001.2014.3001.5501
dataset_path = r"E:/workspace/PyCharmProject/dem_feature/dem/512"
# 打开数据集的txt, 逐行读取图片名
with open(os.path.join(dataset_path, "dem/train.txt"), "r") as f:
train_lines = f.readlines()
with open(os.path.join(dataset_path, "dem/val.txt"), "r") as f:
val_lines = f.readlines()
train_dataset = MyDataset(dataset_path, img_size=512,train_lines=train_lines)
train_dataloader = DataLoader(train_dataset,batch_size=8,shuffle=False)
for iteration,data in enumerate(train_dataloader):
imgs,labels = data
print(imgs,labels)
版权说明 : 本文为转载文章, 版权归原作者所有 版权申明
原文链接 : https://wang11.blog.csdn.net/article/details/125666337
内容来源于网络,如有侵权,请联系作者删除!