python KL发散损失方程

9wbgstp7  于 2022-12-21  发布在  Python
关注(0)|答案(1)|浏览(223)

我有一个关于KL发散损失的快速问题,因为我在研究过程中看到了许多不同的实现。最常见的两个是这两个。然而,在查看数学方程时,我不确定是否应该包括平均值。

KL_loss = -0.5 * torch.sum(1 + torch.log(sigma**2) - mean**2 - sigma**2)

OR 

KL_loss = -0.5 * torch.sum(1 + torch.log(sigma**2) - mean**2 - sigma**2)
KL_loss = torch.mean(KL_loss)

谢谢大家!

daolsyd0

daolsyd01#

此处使用的公式计算单个示例的损失:

对于批量数据,我们需要计算多个示例的损失。
使用我们的每个示例公式,我们得到多个损失值,每个示例1个。我们需要一些方法将每个示例损失计算减少到单个标量值。最常见的是,您希望取批处理的平均值。您将看到大多数pytorch's loss functions使用reduction="mean"。取平均值而不是总和的优点是我们的损失成为批量大小不变的(即不随批量而缩放)。
从带有实现的stackoverflow post you linked中,您将看到第一个和第二个链接的实现取批处理的平均值(即除以批处理大小)。

KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
...
(BCE + KLD) / x.size(0)
KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
...
(NLL_loss + KL_weight * KL_loss) / batch_size

第三个链接的实现不仅取批处理的平均值,而且取sigma/mu向量本身的平均值:

0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

因此,不是按1/N(其中N是批大小)缩放总和,而是按1/(NM)(其中Mmusigma向量的维数)缩放总和。你的损失是批量大小和潜在维度大小不变的。重要的是要注意,缩放你的损失不会改变 "形状"(最优点保持固定),它只是缩放它(你可以通过学习率控制如何逐步通过)。

相关问题