我在使用NumPy时遇到以下问题:
代码:
import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)
输出:
['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!
我可以看到NumPy以某种方式从函数返回的第一个值推断数据类型。我想出了以下解决方案-从函数中返回NumPy数组,显式声明dtype而不是string,并重塑结果:
def get_label_2(x):
if x.sum() <= 10:
return np.array(['SMALL'], dtype='|S5')
else:
return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])
你知道这个问题的更优雅的解决方案吗?
3条答案
按热度按时间4uqofj5v1#
可以使用
np.where
:在函数中:
9wbgstp72#
apply_along_axis
不是一个优雅的解决方案;很方便,但不快。本质上是这样的除了它概括了形状并推导出
res
dtype。对于这样一个简单的“遍历行”的情况,你也可以这样做:
优雅的解决方案是避免Python级别的循环,显式或隐藏,而是使用编译的数组方法,例如给
sum
一个轴:disbfnqx3#
对我来说,最优雅的解决方案是将numpy数组转换为pandas DataFrame,然后使用pandas.DataFrame.apply函数,它不会执行任何不必要的转换:
输出: