我有一个关于KL发散损失的快速问题,因为我在研究过程中看到了许多不同的实现。最常见的两个是这两个。然而,在查看数学方程时,我不确定是否应该包括平均值。
KL_loss = -0.5 * torch.sum(1 + torch.log(sigma**2) - mean**2 - sigma**2)
OR
KL_loss = -0.5 * torch.sum(1 + torch.log(sigma**2) - mean**2 - sigma**2)
KL_loss = torch.mean(KL_loss)
谢谢大家!
1条答案
按热度按时间daolsyd01#
此处使用的公式计算单个示例的损失:
对于批量数据,我们需要计算多个示例的损失。
使用我们的每个示例公式,我们得到多个损失值,每个示例1个。我们需要一些方法将每个示例损失计算减少到单个标量值。最常见的是,您希望取批处理的平均值。您将看到大多数pytorch's loss functions使用
reduction="mean"
。取平均值而不是总和的优点是我们的损失成为批量大小不变的(即不随批量而缩放)。从带有实现的stackoverflow post you linked中,您将看到第一个和第二个链接的实现取批处理的平均值(即除以批处理大小)。
第三个链接的实现不仅取批处理的平均值,而且取sigma/mu向量本身的平均值:
因此,不是按
1/N
(其中N
是批大小)缩放总和,而是按1/(NM)
(其中M
是mu
和sigma
向量的维数)缩放总和。你的损失是批量大小和潜在维度大小不变的。重要的是要注意,缩放你的损失不会改变 "形状"(最优点保持固定),它只是缩放它(你可以通过学习率控制如何逐步通过)。