我有gt
和pred
图像,希望仅在mask
给定的像素子集上计算VGG损失。mask
具有与gt
相同的空间分辨率,但是ON像素不处于任何规则的几何形状中。如何做到这一点?请注意,损失是在VGG的更深层上计算的,因此激活图分辨率将小于gt
。
我能想到的一个可能的解决方案是,
gt[~mask] = const
pred[~mask] = const
torch.nn.functional.mse(VGG(gt), VGG(pred))
我相信,由于非遮罩像素已经人为地由相同的值构成,因此由于这些像素上的不匹配而导致的梯度将为0。
这是计算掩蔽VGG损失的正确方法吗?
1条答案
按热度按时间nue99wik1#
从代码的Angular 来看,你的建议是有意义的。我假设
pred
是一个生成的图像,你想通过生成它的模型反向传播梯度。现在你必须考虑这是否真的是你想做的。VGG损失是一种感知损失,旨在使
pred
和gt
看起来类似于预训练的VGG网络。如果你屏蔽了输入的区域,你只能推测VGG损失会做什么(考虑到它在训练过程中可能从未见过这样的数据)。如果你真的想使用VGG损失,我建议你自己重新训练VGG,屏蔽输入。理想情况下,掩码分布与生成任务中得到的掩码分布类似。
https://paperswithcode.com/method/vgg-loss