pytorch 如何计算掩蔽VGG损失

xdnvmnnf  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(108)

我有gtpred图像,希望仅在mask给定的像素子集上计算VGG损失。mask具有与gt相同的空间分辨率,但是ON像素不处于任何规则的几何形状中。如何做到这一点?请注意,损失是在VGG的更深层上计算的,因此激活图分辨率将小于gt
我能想到的一个可能的解决方案是,

gt[~mask] = const
pred[~mask] = const

torch.nn.functional.mse(VGG(gt), VGG(pred))

我相信,由于非遮罩像素已经人为地由相同的值构成,因此由于这些像素上的不匹配而导致的梯度将为0。
这是计算掩蔽VGG损失的正确方法吗?

nue99wik

nue99wik1#

从代码的Angular 来看,你的建议是有意义的。我假设pred是一个生成的图像,你想通过生成它的模型反向传播梯度。
现在你必须考虑这是否真的是你想做的。VGG损失是一种感知损失,旨在使predgt看起来类似于预训练的VGG网络。如果你屏蔽了输入的区域,你只能推测VGG损失会做什么(考虑到它在训练过程中可能从未见过这样的数据)。
如果你真的想使用VGG损失,我建议你自己重新训练VGG,屏蔽输入。理想情况下,掩码分布与生成任务中得到的掩码分布类似。
https://paperswithcode.com/method/vgg-loss

相关问题