PyTorch:如何将 torchvision.transforms.AugMix 用于 torch.float32 图像?
我正在尝试使用 torchvision.transforms.AugMix 对图像数据集进行数据增强,但出现以下错误:TypeError: Only torch.uint8 image tensors are supported, but found torch.float32(只支持 torch.uint8 图像张量,但得到的是 torch.float32)。我尝试把张量转换成 int 类型,但又出现了另一个错误。
我尝试使用AugMix函数的代码:
# Preprocessing pipeline applied to every patch image in load_bag_tensor.
# NOTE(review): this is the ordering that triggers the TypeError quoted in the
# question. ToTensor/Normalize hand AugMix a float32 tensor, while AugMix only
# accepts uint8 image tensors. Per the first answer, applying AugMix before
# ToTensor (i.e. Resize -> AugMix -> ToTensor -> Normalize) avoids the error.
transform = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((224, 224)), # resize to 224*224
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # normalization
# AugMix runs last here, so it receives the normalized float32 tensor.
torchvision.transforms.AugMix()
]
)
# Stand-alone ToTensor instance; not referenced anywhere else in the visible code.
to_tensor = torchvision.transforms.ToTensor()
# Presumably PIL.Image: setting MAX_IMAGE_PIXELS to None removes Pillow's
# image-size (decompression bomb) limit -- TODO confirm the import.
Image.MAX_IMAGE_PIXELS = None
class BreastDataset(torch.utils.data.Dataset):
    """Dataset of "bags", where each bag is the stack of image patches of one entry.

    The JSON file at ``json_path`` must contain a list of records, each with
    the keys ``"label"``, ``"id"`` and ``"patch_paths"`` (paths relative to
    ``data_dir_path``).

    Args:
        json_path: Path of the JSON index file.
        data_dir_path: Directory that the patch paths are relative to.
        clinical_data_path: Accepted for interface compatibility; not used by
            the code in this class.
        is_preloading: When true, every bag is loaded into memory in
            ``__init__``; otherwise bags are loaded lazily in ``__getitem__``.
    """

    def __init__(self, json_path, data_dir_path='./dataset', clinical_data_path=None, is_preloading=True):
        self.data_dir_path = data_dir_path
        self.is_preloading = is_preloading
        with open(json_path) as f:
            print(f"load data from {json_path}")
            self.json_data = json.load(f)

        # __getitem__ reads self.bag_tensor_list whenever is_preloading is
        # set, so it has to be built here; previously it was never assigned
        # and the first indexed access raised AttributeError.
        self.bag_tensor_list = None
        if self.is_preloading:
            self.bag_tensor_list = [
                self.load_bag_tensor(self._resolve_paths(item["patch_paths"]))
                for item in self.json_data
            ]

    def __len__(self):
        """Return the number of records in the JSON index."""
        return len(self.json_data)

    def __getitem__(self, index):
        """Return a dict with ``bag_tensor``, ``label``, ``patient_id`` and ``patch_paths``."""
        record = self.json_data[index]
        label = int(record["label"])
        patient_id = record["id"]
        patch_paths = record["patch_paths"]

        data = {}
        if self.is_preloading:
            data["bag_tensor"] = self.bag_tensor_list[index]
        else:
            data["bag_tensor"] = self.load_bag_tensor(self._resolve_paths(patch_paths))
        data["label"] = label
        data["patient_id"] = patient_id
        data["patch_paths"] = patch_paths
        return data

    def _resolve_paths(self, patch_paths):
        """Prefix every relative patch path with ``self.data_dir_path``."""
        return [os.path.join(self.data_dir_path, p_path) for p_path in patch_paths]

    def load_bag_tensor(self, patch_paths):
        """Load a bag data as tensor with shape [N, C, H, W]"""
        patch_tensor_list = []
        for p_path in patch_paths:
            # Open inside a context manager so the file handle is released as
            # soon as the pixel data has been converted.
            with Image.open(p_path) as img:
                patch = img.convert("RGB")
            patch_tensor_list.append(transform(patch))  # [C, H, W]
        # Stacking along a new leading axis gives [N, C, H, W].
        return torch.stack(patch_tensor_list, dim=0)
任何帮助都是感激不尽的!提前谢谢你!
2条答案
按热度按时间dy2hfwbg1#
对我来说,先应用 AugMix,然后再应用 ToTensor(),就可以正常工作。
yzuktlbb2#
torchvision.transforms.AugMix 接受 uint8 格式的图像。这意味着每个像素由 1 个(灰度)或 3 个(RGB)介于 0 到 255 之间的数值表示,这是图像的经典格式。torch.Tensor.type(torch.float32) 会把 uint8 张量转换为 float32,但这不太可能是应用在图像上的唯一变换。float32 图像通常会被归一化到 [-1, 1] 或 [0, 1] 范围内。通常的做法是:一旦你弄清楚自己的数据属于哪种情况,就可以把它重新转换为 uint8: