我想对一个numpy
数组应用一个函数,该函数遍历无穷大以获得正确的值:
def relu(x):
odds = x / (1-x)
lnex = np.log(np.exp(odds) + 1)
return lnex / (lnex + 1)
x = np.linspace(0,1,10)
np.where(x==1,1,relu(x))
字符串
正确计算
array([0.40938389, 0.43104202, 0.45833921, 0.49343414, 0.53940413,
0.60030842, 0.68019731, 0.77923729, 0.88889303, 1. ])
型
但也会发出警告:
3478817693.py:2: RuntimeWarning: divide by zero encountered in divide
odds = x / (1-x)
3478817693.py:4: RuntimeWarning: invalid value encountered in divide
return lnex / (lnex + 1)
型
如何避免警告?
请注意,性能在这里是至关重要的,所以我宁愿避免创建中间数组。
1条答案
按热度按时间kt06eoxx1#
另一个可能的解决方案,基于
np.divide
,以避免被零除。这个解决方案的灵感来自@hpaulj的评论。字符串
输出量:
型