pytorch 运行U-Net脚本时出现Assert错误

e5njpo68  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(212)

这是这个问题的延续。虽然我解决了这些问题,但我仍然得到了另一个问题。有没有人能在这方面帮助我?
看起来预测遮罩和实际遮罩的大小不同?
输出代码如下:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_18/459131192.py in <module>
     25             with torch.set_grad_enabled(phase == "train"):
     26                 y_pred = unet(x)
---> 27                 loss = dsc_loss(y_pred, y_true)
     28                 running_loss += loss.item()
     29 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input,**kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input,**kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_18/3969884729.py in forward(self, y_pred, y_true)
      6 
      7     def forward(self, y_pred, y_true):
----> 8         assert y_pred.size() == y_true.size()
      9         y_pred = y_pred[:, 0].contiguous().view(-1)
     10         y_true = y_true[:, 0].contiguous().view(-1)

AssertionError:

下面是U-Net的模型,请大家看一下。
unet_网络.py:


# Unet

# https://github.com/mateuszbuda/brain-segmentation-pytorch

from collections import OrderedDict

import torch
import torch.nn as nn

class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=8):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

感谢并致以最诚挚的问候
迈克尔·施罗特

kxxlusnw

kxxlusnw1#

您的错误源于预测(pred=torch.Size([5, 1, 512, 512]))和目标(y_true=torch.Size([5, 3, 512, 512]))之间的通道数差异。
对于一个有3个通道的目标,你需要你的pred也有3个通道,也就是说,你需要配置你的UNetout_channels=3而不是默认的1。

相关问题