我不明白Pytorch是如何进行日志规范化的,我找不到一个好的例子/解释。有人能提供一个解释吗?
eg
input_tensor =torch.tensor([0.6,0.4])
m = dist.Categorical(logits = input_tensor)
print(np.log(input_tensor))
print(m.logits)
gives:
tensor([-0.5981, -0.7981])
tensor([-0.5108, -0.9163])
字符串
我的概率之和为1,所以没有什么需要标准化的,但Pytorch正在转换我的输入。
Pytorch的文档说:
- logits参数将被解释为未标准化的对数概率,因此可以是任何真实的数字。它同样将被标准化,以便得到的概率之和沿最后一个维度沿着为1。logits将返回此标准化值。*
1条答案
按热度按时间tct7dpnv1#
np.log
函数通过表达式计算值,
因此,如果输入值是
[0.6, 0.4]
,则结果输出将是tensor([-0.5108, -0.9163])
。在处理
Categorical
时,您可以在这里探索PyTorchcategorical.py
源代码:https://github.com/pytorch/pytorch/blob/main/torch/distributions/categorical.py)。你会发现一行代码:
字符串
考虑输入Tensor为
[0.6, 0.4]
,因此logits.logsumexp(dim=-1, keepdim=True)
的值为的数据
因此,这就是为什么
m.logits
产生tensor([-0.5981, -0.7981])
。