我尝试使用numpy.argpartition
从数组中获取n
最小值。但是,我不能保证数组中至少有n
值。如果少于n
值,我只需要整个数组。
目前,我通过检查数组大小来处理这个问题,但是我觉得我缺少了一个可以避免这种分支检查的本地numpy方法。
if np.size(arr) < N:
return arr
else:
return arr[np.argpartition(arr, N)][:N]
最小可重现示例:
import numpy as np
#Find the 4 smallest values in the array
#Arrays can be arbitrarily sized, as it's the result of finding all elements in a larger array
# that meet a threshold
small_arr = np.array([3,1,4])
large_arr = np.array([3,1,4,5,0,2])
#For large_arr, I can use np.argpartition just fine:
large_idx = np.argpartition(large_arr, 4)
#large_idx now contains array([4, 5, 1, 0, 2, 3])
#small_arr results in an indexing error doing the same thing:
small_idx = np.argpartition(small_arr, 4)
#ValueError: kth(=4) out of bounds (3)
我已经浏览了numpy文档中的截断、最大长度和其他类似的术语,但是没有出现我需要的东西。
2条答案
按热度按时间0g0grzrc1#
一种方法(取决于您的情况)是,当数组较短时,将arg限制为
min
:切片允许比数组长度更大的值,只有
argpartition
需要min()
检查。这比你的分支版本效率要低,因为即使你只想要整个数组,它也必须执行
argpartition
,但它更简洁,所以这取决于你的优先级是什么--我个人可能会保留分支或使用三进制:rm5edbpk2#
您可以尝试:
其中
n
是所需的最小值的数量。