PyTorch:如何将相同的随机变换应用于多个图像?

qcuzuvrc  于 2023-10-20  发布在  其他
关注(0)|答案(5)|浏览(176)

我正在为一个包含许多图像对的数据集编写一个简单的转换。作为数据增强,我想对每一对应用一些随机变换,但该对中的图像应该以相同的方式进行变换。例如,给定一对两个图像AB,如果A被水平翻转,则B必须被水平翻转为A。然后下一对CD应该与AB不同地变换,但是CD以相同的方式变换。我在下面的方法中尝试

import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_ajpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(), 
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))

random.seed(1)
display(transform(img_c))
display(transform(img_d))

然而,上面的代码没有选择相同的转换,正如我所测试的,它取决于transform被调用的次数。
有没有办法强制transforms.RandomChoice在指定时使用相同的转换?

tyg4sfes

tyg4sfes1#

我意识到OP要求使用torchvision的解决方案,我认为@Ivan的answer很好地解决了这个问题。
然而,对于那些没有绑定到特定增强库的人,我想指出的是,Albumentations似乎在native fashion中很好地处理了这类情况,允许用户将多个源图像,框等传递到同一个转换中。返回的结构为dict

import albumentations as A

transform = A.Compose(
    transforms=[
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5)],
    additional_targets={'image0': 'image', 'image1': 'image'}
)
transformed = transform(image=image, image0=image0, image1=image1)

现在您可以访问transformed['image0']transformed['image1']等,所有这些都将应用随机参数

ocebsuys

ocebsuys2#

为输入和目标引用随机变换?我想这可能是最干净的方法了。在应用任何转换之前保存随机状态,并在每次后续调用时恢复它

t = transforms.RandomRotation(degrees=360)
state = torch.get_rng_state()
x = t(x)
torch.set_rng_state(state)
y = t(y)
0h4hbjxa

0h4hbjxa3#

简单地说,将PyTorch中的随机化部分放入if语句中。下面的代码使用vflip。类似地,对于水平或其他变换。

import random
import torchvision.transforms.functional as TF

if random.random() > 0.5:
    image = TF.vflip(image)
    mask  = TF.vflip(mask)

这个问题已经在PyTorch forum中讨论过了。在官方GitHub存储库page上讨论了几种解决方案的优缺点。PyTorch的维护人员已经提出了这种简单的方法。
不要使用torchvision.transforms.RandomVerticalFlip(p=1)。使用torchvision.transforms.functional.vflip
函数转换给予对转换管道的细粒度控制。与上面的变换相反,函数变换不包含用于其参数的随机数生成器。这意味着您必须指定/生成所有参数,但您可以重用函数转换。

3ks5zfa0

3ks5zfa04#

我不知道一个函数来修复随机输出。也许可以尝试不同的逻辑,比如自己创建随机化,以便能够重用相同的转换。逻辑:

  • 生成随机数
  • 基于该数字对两个图像应用变换
  • 生成另一个随机数
  • 对另外两个图像做同样的操作,试试这个:
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_ajpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

if random.random() > 0.5:
        image_a_flipped = transforms.functional_pil.vflip(img_a)
        image_b_flipped = transforms.functional_pil.vflip(img_b)
else:
    image_a_flipped = transforms.functional_pil.hflip(img_a)
    image_b_flipped = transforms.functional_pil.hflip(img_b)

if random.random() > 0.5:
        image_c_flipped = transforms.functional_pil.vflip(img_c)
        image_d_flipped = transforms.functional_pil.vflip(img_d)
else:
    image_c_flipped = transforms.functional_pil.hflip(img_c)
    image_d_flipped = transforms.functional_pil.hflip(img_d)
    
display(image_a_flipped)
display(image_b_flipped)

display(image_c_flipped)
display(image_d_flipped)
vcirk6k6

vcirk6k65#

通常,解决方法是在第一个图像上应用变换,检索该变换的参数,然后在其余图像上应用具有这些参数的确定性变换。但是,这里RandomChoice没有提供API来获取所应用的转换的参数,因为它涉及可变数量的转换。在这些情况下,我通常实现对原始函数的覆盖。
看看torchvision的实现,它就像这样简单:

class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)

这里有两个可能的解决方案。
1.您可以从__init__上的转换列表中采样,而不是从__call__上采样:

import random
import torchvision.transforms as T

class RandomChoice(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.t = random.choice(self.transforms)

    def __call__(self, img):
        return self.t(img)

因此,您可以:

transform = RandomChoice([
     T.RandomHorizontalFlip(), 
     T.RandomVerticalFlip()
])
display(transform(img_a)) # both img_a and img_b will
display(transform(img_b)) # have the same transform

transform = RandomChoice([
    T.RandomHorizontalFlip(), 
    T.RandomVerticalFlip()
])
display(transform(img_c)) # both img_c and img_d will
display(transform(img_d)) # have the same transform

1.或者更好的是,批量转换图像:

import random
import torchvision.transforms as T

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
       super().__init__()
       self.transforms = transforms

    def __call__(self, imgs):
        t = random.choice(self.transforms)
        return [t(img) for img in imgs]

它允许做:

transform = RandomChoice([
     T.RandomHorizontalFlip(), 
     T.RandomVerticalFlip()
])

img_at, img_bt = transform([img_a, img_b])
display(img_at) # both img_a and img_b will
display(img_bt) # have the same transform

img_ct, img_dt = transform([img_c, img_d])
display(img_ct) # both img_c and img_d will
display(img_dt) # have the same transform

相关问题