从数组大小可能小于N的numpy数组中获取N个最小值

xzlaal3s  于 2023-01-26  发布在  其他
关注(0)|答案(2)|浏览(141)

我尝试使用numpy.argpartition从数组中获取n最小值。但是,我不能保证数组中至少有n值。如果少于n值,我只需要整个数组。
目前,我通过检查数组大小来处理这个问题,但是我觉得我缺少了一个可以避免这种分支检查的本地numpy方法。

if np.size(arr) < N: 
    return arr 
else:
    return arr[np.argpartition(arr, N)][:N]

最小可重现示例:

import numpy as np

#Find the 4 smallest values in the array
#Arrays can be arbitrarily sized, as it's the result of finding all elements in a larger array
# that meet a threshold
small_arr = np.array([3,1,4])
large_arr = np.array([3,1,4,5,0,2])

#For large_arr, I can use np.argpartition just fine:
large_idx = np.argpartition(large_arr, 4)
#large_idx now contains array([4, 5, 1, 0, 2, 3])

#small_arr results in an indexing error doing the same thing:
small_idx = np.argpartition(small_arr, 4)
#ValueError: kth(=4) out of bounds (3)

我已经浏览了numpy文档中的截断、最大长度和其他类似的术语,但是没有出现我需要的东西。

0g0grzrc

0g0grzrc1#

一种方法(取决于您的情况)是,当数组较短时,将arg限制为min

return arr[np.argpartition(arr, min(N, arr.size - 1)][:N]

切片允许比数组长度更大的值,只有argpartition需要min()检查。
这比你的分支版本效率要低,因为即使你只想要整个数组,它也必须执行argpartition,但它更简洁,所以这取决于你的优先级是什么--我个人可能会保留分支或使用三进制:

return arr if arr.size < N else arr[np.argpartition(arr, N)][:N]
rm5edbpk

rm5edbpk2#

您可以尝试:

def nsmallest(array, n):
    return array[np.argsort(array)[:n]]

其中n是所需的最小值的数量。

相关问题