PyTorch MSE loss differs from a direct calculation by a factor of 2

Posted by 41ik7eoe on 2023-10-20 in Other

Why does torch.nn.functional.mse_loss(x1, x2) give a different result from a direct calculation of the MSE?
Test code to reproduce:

    import torch
    import numpy as np
    # Think of x1 as predicted 2D coordinates and x2 as ground truth
    x1 = torch.rand(10, 2)
    x2 = torch.rand(10, 2)
    mse_torch = torch.nn.functional.mse_loss(x1, x2)
    print(mse_torch)  # 0.1557
    mse_direct = torch.nn.functional.pairwise_distance(x1, x2).square().mean()
    print(mse_direct)  # 0.3314
    mse_manual = 0
    for i in range(len(x1)):
        mse_manual += np.square(np.linalg.norm(x1[i] - x2[i])) / len(x1)
    print(mse_manual)  # 0.3314

As we can see, torch's mse_loss returns 0.1557, which differs from the 0.3314 produced by the manual MSE calculation.

In fact, the mse_loss result is exactly the direct result divided by the dimension of the points (here 2).

What is going on here?
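The observation is not specific to 2D points: for points of any dimension D, the per-point squared distance averaged over points is D times the element-wise mean. A minimal pure-Python sketch that mimics the two reductions without depending on torch (the helper names mse_all_elements and mean_squared_distance are hypothetical, not part of PyTorch):

```python
import random

def mse_all_elements(pred, target):
    # Mean over all N*D squared element differences (what mse_loss's default reduction='mean' averages over)
    diffs = [(p - t) ** 2 for row_p, row_t in zip(pred, target) for p, t in zip(row_p, row_t)]
    return sum(diffs) / len(diffs)

def mean_squared_distance(pred, target):
    # Squared Euclidean distance per point (summed over the D coordinates), then mean over the N points
    dists = [sum((p - t) ** 2 for p, t in zip(row_p, row_t)) for row_p, row_t in zip(pred, target)]
    return sum(dists) / len(dists)

random.seed(0)
for d in (1, 2, 3, 5):
    pred = [[random.random() for _ in range(d)] for _ in range(10)]
    target = [[random.random() for _ in range(d)] for _ in range(10)]
    ratio = mean_squared_distance(pred, target) / mse_all_elements(pred, target)
    print(d, round(ratio, 6))  # the ratio equals d, up to floating-point rounding
```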

Answer #1, by jgzswidk

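The difference is that torch.nn.functional.mse_loss(x1, x2) does not sum over the coordinates when computing the squared error: with the default reduction='mean' it averages the squared differences over all N×D elements. torch.nn.functional.pairwise_distance and np.linalg.norm, on the other hand, sum the squared differences over the coordinates of each point, so averaging those over the N points gives a value D times larger. The mse_loss value can be reproduced as follows: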

    import torch
    import numpy as np
    x1 = torch.rand(10, 2)
    x2 = torch.rand(10, 2)
    mse_torch = torch.nn.functional.mse_loss(x1, x2)
    print(mse_torch)  # 0.1557
    # Treat each coordinate separately: 1D distance per coordinate, squared, averaged over points
    mse_manual = 0
    x3 = torch.zeros(10, 2)
    for i in range(len(x1)):
        x3[i, :1] += (torch.nn.functional.pairwise_distance(x1[i, :1], x2[i, :1], eps=0.0) ** 2) / len(x1)
        x3[i, 1:] += (torch.nn.functional.pairwise_distance(x1[i, 1:], x2[i, 1:], eps=0.0) ** 2) / len(x1)
        mse_manual += x3[i]
    print(mse_manual.mean())  # 0.1557
    # Same idea with element-wise squares and no sum over coordinates
    mse_manual = 0
    for i in range(len(x1)):
        mse_manual += np.square(x1[i] - x2[i]) / len(x1)
    print(mse_manual.mean())  # 0.1557
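Written out for N points of dimension D (N = 10 and D = 2 in the code above), the two quantities differ only in the denominator:

```latex
\text{mse\_loss}(x_1, x_2) = \frac{1}{ND} \sum_{i=1}^{N} \sum_{d=1}^{D} \left( x_{1,id} - x_{2,id} \right)^2

\text{direct} = \frac{1}{N} \sum_{i=1}^{N} \left\lVert x_{1,i} - x_{2,i} \right\rVert_2^2
             = \frac{1}{N} \sum_{i=1}^{N} \sum_{d=1}^{D} \left( x_{1,id} - x_{2,id} \right)^2
             = D \cdot \text{mse\_loss}(x_1, x_2)
```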

Alternatively, if you want a modified MSE loss that reproduces the pairwise-distance result, you can do it like this:

    import torch
    import numpy as np
    # Think of x1 as predicted 2D coordinates and x2 as ground truth
    x1 = torch.rand(10, 2)
    x2 = torch.rand(10, 2)
    # reduction='none' keeps the per-element squared errors: sum over coordinates, then average over points
    mse_torch = torch.nn.functional.mse_loss(x1, x2, reduction='none')
    print(mse_torch.sum(-1).mean())  # 0.3314
    mse_direct = torch.nn.functional.pairwise_distance(x1, x2).square().mean()
    print(mse_direct)  # 0.3314
    mse_manual = 0
    for i in range(len(x1)):
        mse_manual += np.square(np.linalg.norm(x1[i] - x2[i])) / len(x1)
    print(mse_manual)  # 0.3314
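The reduction argument of mse_loss offers a few equivalent ways to get the mean squared Euclidean distance per point. A minimal sketch, assuming a fixed seed (torch.manual_seed(0), not in the original) so the run is reproducible; the variable names via_none, via_sum and via_mean are illustrative only:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the run is reproducible (assumption, not in the original)
x1 = torch.rand(10, 2)  # predicted 2D points
x2 = torch.rand(10, 2)  # ground-truth 2D points
n, d = x1.shape

# Mean of squared Euclidean distances per point (the "direct" MSE from the question);
# eps=0.0 so no small constant is added to the differences
direct = F.pairwise_distance(x1, x2, eps=0.0).square().mean()

via_none = F.mse_loss(x1, x2, reduction='none').sum(dim=-1).mean()  # sum over coords, mean over points
via_sum = F.mse_loss(x1, x2, reduction='sum') / n                   # total squared error / number of points
via_mean = F.mse_loss(x1, x2) * d                                   # default 'mean' divides by n*d, so multiply by d

for name, value in [("none", via_none), ("sum", via_sum), ("mean*d", via_mean)]:
    print(name, value.item(), torch.allclose(value, direct))
```

The reduction='sum' / n form avoids materializing the per-element tensor and does not need to know the point dimension, which makes it a reasonable choice when D varies between datasets.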
