Describe the Bug
Using paddle.nn.functional.softmax
on a nearly symmetric logits matrix whose main-diagonal entries are all equal produces diagonal outputs that are not all equal. The following code reproduces the problem:
import paddle as p
from paddle.nn import functional as PF
import torch
from torch.nn import functional as TF

print(f'torch-version: {torch.version.git_version}')
print(f'paddle-version: {p.__git_commit__}')
print()

seqlen = 8

# torch: row i gets 2.0 at column i, 1.0 at column (i + 1) % 8, 0.0 elsewhere
x = torch.arange(seqlen).long()
y = torch.arange(seqlen).long()
x = x % 8
y = (y + 1) % 8
x = TF.one_hot(x, num_classes=8).float()
y = TF.one_hot(y, num_classes=8).float()
x = 2 * x + y
x = x.cuda()  # run the torch softmax on the GPU
torch_res = TF.softmax(x, -1).cpu().numpy()

# paddle: build the same input (paddle's one_hot returns float32)
x = p.arange(seqlen).cast('int64')
y = p.arange(seqlen).cast('int64')
x = x % 8
y = (y + 1) % 8
x = PF.one_hot(x, num_classes=8)
y = PF.one_hot(y, num_classes=8)
x = 2 * x + y
paddle_res = PF.softmax(x, -1).numpy()

# Element-wise difference between the two results
print(f'torch-vs-paddle diff: {paddle_res - torch_res}')

# Main diagonals: every input diagonal entry is 2.0, so all outputs should match
paddle_diag = paddle_res[range(8), range(8)]
torch_diag = torch_res[range(8), range(8)]
print()
print(f'paddle-diagonal:{paddle_diag}')
print(f'torch-diagonal:{torch_diag}')
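As a framework-independent cross-check (an illustration added here, not part of the original report), the sketch below rebuilds the same input with NumPy and computes the softmax in float64, summing each row with math.fsum, whose correctly rounded result does not depend on summation order. Every row holds the same multiset of values (one 2, one 1, six 0s), so in exact arithmetic every diagonal entry equals e^2 / (e^2 + e + 6) ≈ 0.4587385. The helper name reference_softmax is hypothetical.

```python
import math

import numpy as np


def reference_softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax in float64 with an order-independent denominator.

    Hypothetical helper for cross-checking framework results: each row is
    shifted by its maximum, exponentiated element by element with math.exp,
    and summed with math.fsum, so permuting a row cannot change its sum.
    """
    x = np.asarray(x, dtype=np.float64)
    out = np.empty_like(x)
    for i, row in enumerate(x):
        m = row.max()
        exps = [math.exp(v - m) for v in row]
        denom = math.fsum(exps)
        out[i] = [e / denom for e in exps]
    return out


seqlen = 8
# Same input as the reproduction: 2.0 on the diagonal, 1.0 at column (i + 1) % 8
x = 2 * np.eye(seqlen) + np.roll(np.eye(seqlen), 1, axis=1)
ref = reference_softmax(x)
diag = ref[np.arange(seqlen), np.arange(seqlen)]

# Round the float64 reference to float32, the dtype both frameworks return
diag32 = diag.astype(np.float32)
print(diag32)
print(np.unique(diag32).size)  # 1: every diagonal entry is identical
```

Comparing paddle_res and torch_res against ref.astype(np.float32) shows which framework deviates from the correctly rounded value, rather than only how far the two deviate from each other.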
Output
torch-version: 49444c3e546bf240bed24a101e747422d1f8a0ee
paddle-version: fd48f88b46d66c536ee3da0a373380746b2d1f05
torch-vs-paddle diff: [[-5.9604645e-08 0.0000000e+00 -7.4505806e-09 -7.4505806e-09
-7.4505806e-09 -7.4505806e-09 -7.4505806e-09 -7.4505806e-09]
[-7.4505806e-09 -5.9604645e-08 0.0000000e+00 -7.4505806e-09
-7.4505806e-09 -7.4505806e-09 -7.4505806e-09 -7.4505806e-09]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.4901161e-08
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
1.4901161e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[-7.4505806e-09 -7.4505806e-09 -7.4505806e-09 -7.4505806e-09
-5.9604645e-08 0.0000000e+00 -7.4505806e-09 -7.4505806e-09]
[-7.4505806e-09 -7.4505806e-09 -7.4505806e-09 -7.4505806e-09
-7.4505806e-09 -5.9604645e-08 0.0000000e+00 -7.4505806e-09]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 1.4901161e-08]
[ 1.4901161e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
paddle-diagonal:[0.45873845 0.45873845 0.4587385 0.4587385 0.45873845 0.45873845
0.4587385 0.4587385 ]
torch-diagonal:[0.4587385 0.4587385 0.4587385 0.4587385 0.4587385 0.4587385 0.4587385
0.4587385]
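As shown, although every entry on the input's main diagonal is 2.0, the diagonal of Paddle's softmax output is not uniform: some entries are 0.45873845 and others are 0.4587385, whereas all diagonal entries of the torch result are 0.4587385.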
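To put the size of these discrepancies in perspective, the sketch below converts the three recurring magnitudes from the diff output into float32 units in the last place (ULPs) with np.spacing. The reference output values (about 0.4587 for the diagonal, 0.1688 for the 1.0 entries, 0.0621 for the 0.0 entries) are the analytic softmax values e^2/S, e/S and 1/S with S = e^2 + e + 6; the helper name diff_in_ulps is hypothetical.

```python
import numpy as np

# Analytic softmax outputs for each kind of input entry (S = e**2 + e + 6)
S = np.e ** 2 + np.e + 6
values = {
    "diagonal (input 2.0)": np.e ** 2 / S,
    "input 1.0": np.e / S,
    "input 0.0": 1 / S,
}
# Recurring magnitudes copied from the torch-vs-paddle diff output
diffs = {
    "diagonal (input 2.0)": 5.9604645e-08,
    "input 1.0": 1.4901161e-08,
    "input 0.0": 7.4505806e-09,
}


def diff_in_ulps(diff: float, value: float) -> float:
    """Express an absolute difference in float32 ULPs at `value` (hypothetical helper)."""
    return diff / float(np.spacing(np.float32(value)))


for name, value in values.items():
    print(f"{name}: {diff_in_ulps(diffs[name], value):.1f} ULP")
```

Every nonzero entry of the diff is therefore 1 or 2 float32 ULPs; the issue is not the magnitude but that rows containing the same values do not yield bit-identical diagonal results.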
Additional Supplementary Information
No response
Answer #1 (pw9qyyiw):
Replacing it with the hand-written softmax above makes the precision match completely.
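A hedged note on why a different softmax implementation can change the last bits: float32 addition is not associative, so a reduction that adds the same values in a different order (for example, an order tied to element position or a different parallel reduction tree) can round differently. This thread does not establish that this is what happens inside paddle.nn.functional.softmax; the sketch below only demonstrates the general effect with NumPy float32 scalars.

```python
import numpy as np

# Three float32 operands; 1e8 is exactly representable and its float32 ULP is 8
a = np.float32(1e8)
b = np.float32(1.0)
c = np.float32(-1e8)

# Same operands, two summation orders
left = (a + b) + c   # 1e8 + 1 rounds back to 1e8, so the 1.0 is lost
right = (a + c) + b  # the large terms cancel exactly first, so the 1.0 survives

print(left, right)  # 0.0 1.0
```

The effect in the softmax above is far smaller (1 to 2 ULPs) because the summed terms have similar magnitudes, but the mechanism of order-dependent rounding is the same.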