交叉熵损失的PyTorch LogSoftmax与Softmax

w8f9ii69  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(202)

据我所知,PyTorch的LogSoftmax函数基本上只是一种计算Log(Softmax(x))的数值稳定性更高的方法。Softmax允许您将线性层的输出转换为分类概率分布。
pytorch documentation表示交叉熵损失将nn.LogSoftmax()nn.NLLLoss()组合在一个类中。
看到NLLLoss,我还是很困惑......是否使用了2个日志?我认为负日志是事件的信息内容。(如entropy
经过进一步的研究,我认为NLLLoss假设你实际上是在传递对数概率而不仅仅是概率。这是正确的吗?如果是这样的话,有点奇怪...

tag5nh1u

tag5nh1u1#

是的,NLLLoss将对数概率(log(softmax(x)))作为输入。为什么?因为如果您添加nn.LogSoftmax(或F.log_softmax)作为模型输出的最后一层,您可以使用torch.exp(output)轻松获得概率,并且为了获得交叉熵损失,您可以直接使用nn.NLLLoss。当然,正如您所说,log-softmax更稳定。
而且,只有一个日志(它在nn.LogSoftmax中)。nn.NLLLoss中没有日志。
nn.CrossEntropyLoss()nn.LogSoftmax()log(softmax(x)))和nn.NLLLoss()合并为一个类。因此,传递到nn.CrossEntropyLoss的网络输出需要是网络的原始输出(称为logits),而不是softmax函数的输出。

相关问题