numpy 避免RuntimeWarning使用where

xtfmy6hx  于 2024-01-08  发布在  其他
关注(0)|答案(1)|浏览(88)

我想对一个numpy数组应用一个函数,该函数遍历无穷大以获得正确的值:

def relu(x):
    odds = x / (1-x)
    lnex = np.log(np.exp(odds) + 1)
    return lnex / (lnex + 1)

x = np.linspace(0,1,10)
np.where(x==1,1,relu(x))

字符串
正确计算

array([0.40938389, 0.43104202, 0.45833921, 0.49343414, 0.53940413,
       0.60030842, 0.68019731, 0.77923729, 0.88889303, 1.        ])


但也会发出警告:

3478817693.py:2: RuntimeWarning: divide by zero encountered in divide
  odds = x / (1-x)
3478817693.py:4: RuntimeWarning: invalid value encountered in divide
  return lnex / (lnex + 1)

如何避免警告?

请注意,性能在这里是至关重要的,所以我宁愿避免创建中间数组。

kt06eoxx

kt06eoxx1#

另一个可能的解决方案,基于np.divide,以避免被零除。这个解决方案的灵感来自@hpaulj的评论。

def relu(x):
    odds = np.divide(x, 1-x, out=np.zeros_like(x), where=x!=1)
    lnex = np.log(np.exp(odds) + 1)
    return lnex / (lnex + 1)

x = np.linspace(0,1,10)
np.where(x==1,1,relu(x))

字符串
输出量:

array([0.40938389, 0.43104202, 0.45833921, 0.49343414, 0.53940413,
       0.60030842, 0.68019731, 0.77923729, 0.88889303, 1.        ])

相关问题