pytorch 骰子损失仅在概率在分母处平方时起作用

tkqqtvp1  于 2023-03-12  发布在  其他
关注(0)|答案(2)|浏览(116)

我遇到了一个有趣的和令人沮丧的问题与骰子损失用于图像分割与Unet。
我必须将图像分为两类:背景和感兴趣区域。感兴趣区域通常占整个图像像素的4%。图像约为1600 x1600像素。我发现骰子损失比交叉熵好得多。但是,如果我使用标准骰子损失公式,我的Unet无法提供正确的输出,即所有像素都被预测为背景。
对于标准骰子损失,我的意思是:

其中x_{c,i}是Unet对像素i和通道c预测的概率,y_{c,i}是对应的地面实况标签。

注意分母处x的平方。
由于某种原因,后者使网络产生正确的输出,尽管损失收敛到~0.5。
我不明白为什么后者有效而前者无效。即使我在分母上使用3的幂,后者也有效。
下面是我的实现:

def make_one_hot(labels, classes):
    one_hot = torch.FloatTensor(labels.size()[0], classes, labels.size()[2], labels.size()[3]).zero_().to(labels.device)
    target = one_hot.scatter_(1, labels.data, 1)
    return target

class DiceLoss(nn.Module):

    def __init__(self,):
        super(DiceLoss, self).__init__()

    def forward(self, output, target):

        target = make_one_hot(target.unsqueeze(dim=1), classes=output.size()[1])
        output = F.softmax(output, dim=1)

        numerator = (output * target).sum(dim=(2, 3))
        denominator = output.pow(2).sum(dim=(2, 3)) + target.sum(dim=(2, 3))

        iou = numerator / denominator

        return 1 - iou.mean()
4nkexdtk

4nkexdtk1#

Milletari等人在the paper of V-Net中提出这个建议时已经对此进行了解释,他们提出ROI可能只占整个扫描的一个非常小的区域,这很可能会偏向背景,既然你说你的ROI大约占整个图像的4%,可能你也面临着类似的问题。

js4nwp54

js4nwp542#

根据this paper,作者称“Milletari等人(2016)提出将分母改为平方形式以加快收敛速度......”他们引用的是V-Net paper,据我所知,这并没有提到他们为什么要对分母中的项求平方。我认为那篇论文中使用的公式仅仅是基于Dice系数,他们将其定义为二值分割体积,所以平方不会改变值。但是如果你使用的是softmax概率,情况就不一样了。
虽然平方项可以创造一个更平滑的损失景观,因此更快的收敛是有道理的。值得一提的是,我使用了骰子损失的非平方版本(1 - DSC),结果很好。它们仍然优化了同样的事情(即区域重叠)。但是,如果你将DSC作为一个性能指标报告,我会使用非平方版本。

相关问题