NumPy mean gives slightly different results depending on row order

Posted by qjp7pelc on 2023-04-21 in Other

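In a test case, we use np.testing.assert_allclose to check that two data sources agree on their mean. However, although both contain the same data in a different row order, the computed means differ slightly. Here is a minimal working example: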

    import numpy as np
    x = np.array(
        [[0.5224021, 0.8526993], [0.6045113, 0.7965965], [0.5053657, 0.86290526], [0.70609194, 0.7081201]],
        dtype=np.float32,
    )
    y = np.array(
        [[0.5224021, 0.8526993], [0.70609194, 0.7081201], [0.6045113, 0.7965965], [0.5053657, 0.86290526]],
        dtype=np.float32,
    )
    print("X mean", x.mean(0))
    print("Y mean", y.mean(0))
    z = x[[0, 3, 1, 2]]
    print("Z", z)
    print("Z mean", z.mean(0))
    np.testing.assert_allclose(z.mean(0), y.mean(0))
    np.testing.assert_allclose(x.mean(0), y.mean(0))

With Python 3.10.6 and NumPy 1.24.2, this produces the following output:

    X mean [0.58459276 0.8050803 ]
    Y mean [0.5845928 0.8050803]
    Z [[0.5224021 0.8526993 ]
     [0.70609194 0.7081201 ]
     [0.6045113 0.7965965 ]
     [0.5053657 0.86290526]]
    Z mean [0.5845928 0.8050803]
    Traceback (most recent call last):
      File "/home/nuric/semafind-db/scribble.py", line 19, in <module>
        np.testing.assert_allclose(x.mean(0), y.mean(0))
      File "/home/nuric/semafind-db/.venv/lib/python3.10/site-packages/numpy/testing/_private/utils.py", line 1592, in assert_allclose
        assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
      File "/usr/lib/python3.10/contextlib.py", line 79, in inner
        return func(*args, **kwds)
      File "/home/nuric/semafind-db/.venv/lib/python3.10/site-packages/numpy/testing/_private/utils.py", line 862, in assert_array_compare
        raise AssertionError(msg)
    AssertionError:
    Not equal to tolerance rtol=1e-07, atol=0
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 5.9604645e-08
    Max relative difference: 1.0195925e-07
    x: array([0.584593, 0.80508 ], dtype=float32)
    y: array([0.584593, 0.80508 ], dtype=float32)

One workaround is to loosen the tolerance of the assertion, but does anyone know why this happens?

Answer #1, by dfddblmv

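Use np.float64 for higher precision. np.float32 holds only about 7 significant decimal digits, and in my experience it is only dependable to about 3 decimal places. With float64, the following code passes: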

    import numpy as np
    x = np.array(
        [[0.5224021, 0.8526993], [0.6045113, 0.7965965], [0.5053657, 0.86290526], [0.70609194, 0.7081201]],
        dtype=np.float64,
    )
    y = np.array(
        [[0.5224021, 0.8526993], [0.70609194, 0.7081201], [0.6045113, 0.7965965], [0.5053657, 0.86290526]],
        dtype=np.float64,
    )
    print("X mean", x.mean(0))
    print("Y mean", y.mean(0))
    z = x[[0, 3, 1, 2]]
    print("Z", z)
    print("Z mean", z.mean(0))
    np.testing.assert_allclose(z.mean(0), y.mean(0))
    np.testing.assert_allclose(x.mean(0), y.mean(0))
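
To put a number on the precision gap, np.finfo reports the machine epsilon of each type. The sketch below also shows a middle ground that is not part of the answer above: keeping the data in float32 but passing dtype=np.float64 to mean, so that only the accumulation runs in double precision. The array values are copied from the question; the names mean_x64 and mean_y64 are illustrative.

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable value.
eps32 = np.finfo(np.float32).eps  # about 1.19e-07
eps64 = np.finfo(np.float64).eps  # about 2.22e-16
print("float32 eps:", eps32)
print("float64 eps:", eps64)

x = np.array(
    [[0.5224021, 0.8526993], [0.6045113, 0.7965965], [0.5053657, 0.86290526], [0.70609194, 0.7081201]],
    dtype=np.float32,
)
y = np.array(
    [[0.5224021, 0.8526993], [0.70609194, 0.7081201], [0.6045113, 0.7965965], [0.5053657, 0.86290526]],
    dtype=np.float32,
)

# The data stays float32; only the sum behind the mean is accumulated in float64.
mean_x64 = x.mean(0, dtype=np.float64)
mean_y64 = y.mean(0, dtype=np.float64)

# The default rtol=1e-07 is now far above the float64 rounding noise.
np.testing.assert_allclose(mean_x64, mean_y64)
```

Since the float32 epsilon is already about 1.2e-07, the default rtol=1e-07 of assert_allclose leaves essentially no headroom for float32 rounding, which is consistent with the failure in the question.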

Another option is to increase the tolerance:

    import numpy as np
    x = np.array(
        [[0.5224021, 0.8526993], [0.6045113, 0.7965965], [0.5053657, 0.86290526], [0.70609194, 0.7081201]],
        dtype=np.float32,
    )
    y = np.array(
        [[0.5224021, 0.8526993], [0.70609194, 0.7081201], [0.6045113, 0.7965965], [0.5053657, 0.86290526]],
        dtype=np.float32,
    )
    print("X mean", x.mean(0))
    print("Y mean", y.mean(0))
    z = x[[0, 3, 1, 2]]
    print("Z", z)
    print("Z mean", z.mean(0))
    np.testing.assert_allclose(z.mean(0), y.mean(0), rtol=1e-6)
    np.testing.assert_allclose(x.mean(0), y.mean(0), rtol=1e-6)
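
Rather than hard-coding 1e-6, the tolerance can be tied to the dtype. This is a minimal sketch, assuming that a few machine epsilons is an acceptable bound for a four-row sum of floating-point data. The helper name assert_means_close and the factor 4 are illustrative choices, not NumPy conventions.

```python
import numpy as np


def assert_means_close(a, b, n_eps=4):
    """Compare column means with a tolerance scaled to the coarser dtype.

    `n_eps` is a hypothetical safety factor: how many machine epsilons
    of relative error to accept. Both inputs must have a floating dtype.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Use the larger (less precise) epsilon of the two inputs.
    eps = max(np.finfo(a.dtype).eps, np.finfo(b.dtype).eps)
    np.testing.assert_allclose(a.mean(0), b.mean(0), rtol=n_eps * eps, atol=0)


x = np.array(
    [[0.5224021, 0.8526993], [0.6045113, 0.7965965], [0.5053657, 0.86290526], [0.70609194, 0.7081201]],
    dtype=np.float32,
)
y = np.array(
    [[0.5224021, 0.8526993], [0.70609194, 0.7081201], [0.6045113, 0.7965965], [0.5053657, 0.86290526]],
    dtype=np.float32,
)

# For float32 the relative tolerance is about 4.8e-07, so reordering passes.
assert_means_close(x, y)
```

A tolerance derived this way loosens automatically for float32 and tightens for float64, while a real discrepancy between the two data sources still fails the test.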

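Finally, the error arises because the sum is accumulated in a different row order for x than for y and z, and every intermediate result is rounded to np.float32, so the totals can differ in the last bit. This is also why z, which has the same row order as y, gives the same mean as y. You can see this by printing more decimal places: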

    import numpy as np
    np.set_printoptions(formatter={'float': lambda x: "{0:0.10f}".format(x)})
    x = np.array(
        [[0.5224021, 0.8526993], [0.6045113, 0.7965965], [0.5053657, 0.86290526], [0.70609194, 0.7081201]],
        dtype=np.float32,
    )
    y = np.array(
        [[0.5224021, 0.8526993], [0.70609194, 0.7081201], [0.6045113, 0.7965965], [0.5053657, 0.86290526]],
        dtype=np.float32,
    )
    print("X mean", x.mean(0))
    print("Y mean", y.mean(0))
    z = x[[0, 3, 1, 2]]
    print("Z", z)
    print("Z mean", z.mean(0))
    np.testing.assert_allclose(z.mean(0), y.mean(0), rtol=1e-6)
    np.testing.assert_allclose(x.mean(0), y.mean(0), rtol=1e-6)

This prints:

    X mean [0.5845927596 0.8050802946]
    Y mean [0.5845928192 0.8050802946]
    Z [[0.5224021077 0.8526992798]
     [0.7060919404 0.7081201077]
     [0.6045113206 0.7965965271]
     [0.5053657293 0.8629052639]]
    Z mean [0.5845928192 0.8050802946]
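
The order dependence comes from the fact that floating-point addition is not associative: each intermediate sum is rounded to the nearest float32. The sketch below uses made-up values (not the data from the question) where the effect is large enough to see directly. It then checks that the 5.9604645e-08 gap reported by assert_allclose equals the spacing between adjacent float32 values near 0.58, which suggests the two means differ by a single last-place step.

```python
import numpy as np

# Illustrative values only; chosen so that the rounding is easy to follow.
a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

# Same three numbers, two groupings, two different float32 results.
left = (a + b) + c   # a + b is exactly 0, then 0 + 1 = 1
right = a + (b + c)  # b + c rounds back to -1e8, so the 1 is lost
print(left, right)   # 1.0 0.0

# Distance from a float32 value near the mean to the next representable one.
step = np.spacing(np.float32(0.5845928))
print(step)          # 5.9604645e-08, which is 2**-24
```

Near 1e8 adjacent float32 values are 8 apart, so adding 1 changes nothing; the data in the question hits the same mechanism at a much smaller scale.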