Numpy apply_沿着_axis推断出错误的数据类型

1cosmwyk  于 2023-06-23  发布在  其他
关注(0)|答案(3)|浏览(158)

我在使用NumPy时遇到以下问题:
代码:

import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)

输出:

['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!

我可以看到NumPy以某种方式从函数返回的第一个值推断数据类型。我想出了以下解决方案-从函数中返回NumPy数组,显式声明dtype而不是string,并重塑结果:

def get_label_2(x):
    if x.sum() <= 10:
        return np.array(['SMALL'], dtype='|S5')
    else:
        return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])

你知道这个问题的更优雅的解决方案吗?

4uqofj5v

4uqofj5v1#

可以使用np.where

arr1 = np.array([[1, 2], [30, 40]])
arr2 = np.array([[30, 40], [1, 2]])

print(np.where(arr1.sum(axis=1)<=10,'SMALL','BIG'))
print(np.where(arr2.sum(axis=1)<=10,'SMALL','BIG'))
['SMALL' 'BIG']
['BIG' 'SMALL']

在函数中:

def get_label(x, threshold, axis=1, label1='SMALL', label2='BIG'):
    return np.where(x.sum(axis=axis) <= threshold, label1, label2)
9wbgstp7

9wbgstp72#

apply_along_axis不是一个优雅的解决方案;很方便,但不快。本质上是这样的

In [277]: get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
In [279]: np.array([get_label(row) for row in np.array([[30,40],[1,2]])])
Out[279]: 
array(['BIG', 'SMALL'],
      dtype='<U5')
In [280]: res = np.zeros((2,),dtype='S5')
In [281]: arr = np.array([[30,40],[1,2]])
In [282]: for i in range(2):
     ...:     res[i] = get_label(arr[i,:])
     ...:     
In [283]: res
Out[283]: 
array([b'BIG', b'SMALL'],
      dtype='|S5')

除了它概括了形状并推导出resdtype。
对于这样一个简单的“遍历行”的情况,你也可以这样做:

In [278]: np.array([get_label(row) for row in np.array([[1,2],[30,40]])])
Out[278]: 
array(['SMALL', 'BIG'],
      dtype='<U5')
In [279]: np.array([get_label(row) for row in np.array([[30,40],[1,2]])])
Out[279]: 
array(['BIG', 'SMALL'],
      dtype='<U5')

优雅的解决方案是避免Python级别的循环,显式或隐藏,而是使用编译的数组方法,例如给sum一个轴:

In [284]: arr.sum(axis=1)
Out[284]: array([70,  3])
disbfnqx

disbfnqx3#

对我来说,最优雅的解决方案是将numpy数组转换为pandas DataFrame,然后使用pandas.DataFrame.apply函数,它不会执行任何不必要的转换:

import numpy as np
import pandas as pd
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[30, 40], [1, 2]])
df = pd.DataFrame(arr) # convert numpy array to pandas dataframe
arr2 = df.apply(get_label,1).to_numpy() # apply function and convert back to numpy array
print(arr2)

输出:

['BIG' 'SMALL']

相关问题