PyTorch: how to get rid of checkerboard artifacts

vxqlmq5t · asked 2022-12-23 in Other

I am using a fully convolutional autoencoder to colorize black-and-white images. However, the output has a checkerboard pattern and I want to get rid of it. The checkerboard artifacts I have seen so far were always much smaller than mine, and the usual way to get rid of them is to replace all unpooling operations with bilinear upsampling (so I have been told).
But I cannot simply replace the unpooling operations, because I work with images of different sizes and therefore need the unpooling; otherwise the output tensor could end up with a different size than the original.
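(To illustrate: MaxUnpool2d takes an output_size argument that restores the exact pre-pooling shape, which upsampling by a fixed factor would not guarantee for odd-sized inputs. A minimal sketch, with illustrative sizes:)

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 7, 7)                # odd spatial size
y, idx = pool(x)                           # -> (1, 1, 3, 3)
z = unpool(y, idx, output_size=x.size())   # -> (1, 1, 7, 7)
# without output_size, the unpooled tensor would be 6x6, not 7x7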

TL;DR:

How can I get rid of these checkerboard artifacts without replacing the unpooling operations?

import torch
import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.leaky_reLU = nn.LeakyReLU(0.2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2, padding=1)
        self.softmax = nn.Softmax2d()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv9 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv10 = nn.ConvTranspose2d(in_channels=64, out_channels=2, kernel_size=3, stride=1, padding=1)

    def forward(self, x):

        # encoder
        x = self.conv1(x)
        x = self.leaky_reLU(x)
        size1 = x.size()
        x, indices1 = self.pool(x)

        x = self.conv2(x)
        x = self.leaky_reLU(x)
        size2 = x.size()
        x, indices2 = self.pool(x)

        x = self.conv3(x)
        x = self.leaky_reLU(x)
        size3 = x.size()
        x, indices3 = self.pool(x)

        x = self.conv4(x)
        x = self.leaky_reLU(x)
        size4 = x.size()
        x, indices4 = self.pool(x)

        ######################
        x = self.conv5(x)
        x = self.leaky_reLU(x)

        x = self.conv6(x)
        x = self.leaky_reLU(x)
        ######################

        # decoder
        x = self.unpool(x, indices4, output_size=size4)
        x = self.conv7(x)
        x = self.leaky_reLU(x)

        x = self.unpool(x, indices3, output_size=size3)
        x = self.conv8(x)
        x = self.leaky_reLU(x)

        x = self.unpool(x, indices2, output_size=size2)
        x = self.conv9(x)
        x = self.leaky_reLU(x)

        x = self.unpool(x, indices1, output_size=size1)
        x = self.conv10(x)
        x = self.softmax(x)

        return x

eqfvzcg8 · Answer 1

Instead of upconv layers such as nn.ConvTranspose2d, you can use interpolation in the decoder part to get back to the initial size, e.g. torch.nn.functional.interpolate. This will prevent the checkerboard artifacts.
If you want learnable weights in the decoder, you should also follow each interpolation with a conv layer such as nn.Conv2d.
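A minimal sketch of this idea (the UpBlock name and channel counts are illustrative, not from the original answer): the interpolation does the resizing without checkerboard artifacts, and the regular convolution after it keeps the block learnable.

import torch.nn as nn
import torch.nn.functional as F

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x, output_size):
        # interpolate to the stored pre-pooling spatial size (e.g. size4[2:]),
        # then convolve so the block still has learnable weights
        x = F.interpolate(x, size=output_size, mode='bilinear', align_corners=False)
        return self.conv(x)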

n3schb8v · Answer 2

Skip connections are commonly used in encoder-decoder architectures; they help produce accurate results by passing appearance information from the shallow layers of the encoder to the corresponding deeper layers of the decoder. Unet is the most widely used encoder-decoder architecture of this type, and Linknet is also popular. It differs from Unet in how the appearance information from the encoder layers is fused into the decoder layers: in Unet, the incoming encoder features are concatenated with the corresponding decoder layer, whereas Linknet performs addition instead, which is why Linknet needs fewer operations per forward pass and is considerably faster than Unet.
Each convolution block in the decoder might look like this:
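(The original answer showed this block as an image, which is not reproduced here. A minimal sketch of such a block, assuming Unet-style concatenation; the DecoderBlock name and channel arguments are illustrative:)

import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat((x, skip), dim=1)  # Unet fuses by concatenation; Linknet would use x + skip instead
        return self.act(self.conv(x))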

I have also attached a diagram describing the Unet and LinkNet architectures. Hopefully using skip connections will help.

gojuced7 · Answer 3

This pattern arises from the deconvolution (nn.ConvTranspose2d). The article explains it in detail.
You can try Upsample as an alternative; it does not produce the checkerboard pattern.
It works as follows:

import torch

input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)  # [[1., 2.], [3., 4.]]
m = torch.nn.Upsample(scale_factor=2, mode='nearest')
m(input)  # each value is repeated in a 2x2 block:
          # [[1., 1., 2., 2.], [1., 1., 2., 2.], [3., 3., 4., 4.], [3., 3., 4., 4.]]

However, you will not be able to learn anything with Upsample; it is just a fixed transformation, so there is a trade-off. There are many papers online on how to deal with the checkerboard pattern for different problems.
The idea is to train your network so that the checkerboard pattern goes away.

z9smfwbn · Answer 4

As Kaushik Roy has specified, skip connections are the way to go!

import torch
import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.leaky_reLU = nn.LeakyReLU(0.2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2, padding=1)
        self.softmax = nn.Softmax2d()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv9 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv10 = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=3, stride=1, padding=1)

    def forward(self, x):

        # encoder
        x = self.conv1(x)
        out1 = self.leaky_reLU(x)
        x = out1
        size1 = x.size()
        x, indices1 = self.pool(x)

        x = self.conv2(x)
        out2 = self.leaky_reLU(x)
        x = out2
        size2 = x.size()
        x, indices2 = self.pool(x)

        x = self.conv3(x)
        out3 = self.leaky_reLU(x)
        x = out3
        size3 = x.size()
        x, indices3 = self.pool(x)

        x = self.conv4(x)
        out4 = self.leaky_reLU(x)
        x = out4
        size4 = x.size()
        x, indices4 = self.pool(x)

        ######################
        x = self.conv5(x)
        x = self.leaky_reLU(x)

        x = self.conv6(x)
        x = self.leaky_reLU(x)
        ######################

        # decoder: fuse the stored encoder activations (skip connections)
        x = self.unpool(x, indices4, output_size=size4)
        x = self.conv7(torch.cat((x, out4), 1))  # 512 (unpooled) + 512 (out4) = 1024 input channels
        x = self.leaky_reLU(x)

        x = self.unpool(x, indices3, output_size=size3)
        x = self.conv8(torch.cat((x, out3), 1))
        x = self.leaky_reLU(x)

        x = self.unpool(x, indices2, output_size=size2)
        x = self.conv9(torch.cat((x, out2), 1))
        x = self.leaky_reLU(x)

        x = self.unpool(x, indices1, output_size=size1)
        x = self.conv10(torch.cat((x, out1), 1))
        x = self.softmax(x)

        return x
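A quick sanity check (the sizes are illustrative, not from the original answer): thanks to the unpooling with stored indices and output sizes, the skip-connection version still handles arbitrary input sizes.

model = AE()
out = model(torch.randn(1, 3, 95, 95))  # deliberately odd spatial size
print(out.shape)                        # torch.Size([1, 2, 95, 95])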

This answer was posted as an edit to the question "How to get rid of checkerboard artifacts" by the OP Stefan, under CC BY-SA 4.0.
