python 根据numpy数组中的值查找numpy数组的索引

4sup72z8  于 2023-01-29  发布在  Python
关注(0)|答案(3)|浏览(310)

我想找到一个更大数组中的索引,如果它们与另一个更小数组中的值相匹配的话。类似于下面的new_array

import numpy as np
summed_rows = np.random.randint(low=1, high=14, size=9999)
common_sums = np.array([7,10,13])
new_array = np.where(summed_rows == common_sums)

但是,这将返回:

__main__:1: DeprecationWarning: elementwise comparison failed; this will raise an error in the future. 
>>>new_array 
(array([], dtype=int64),)

我得到的最接近的答案是:

new_array = [np.array(np.where(summed_rows==important_sum)) for important_sum in common_sums[0]]

这给了我一个包含三个numpy数组的列表(每个“重要和”对应一个),但是每个长度不同,这会导致连接和vstacking的下游问题。需要说明的是,我不想使用上面的代码行。我想使用numpy索引summed_rows。我已经看过使用numpy.wherenumpy.argwherenumpy.intersect1d的各种答案,但是我很难把这些想法联系起来。我想我错过了一些简单的东西,问一下会更快。
提前感谢您的推荐!

uqxowvwt

uqxowvwt1#

考虑到注解中建议的选项,并使用numpy的in 1d选项添加一个额外的选项:

>>> import numpy as np
>>> summed_rows = np.random.randint(low=1, high=14, size=9999)
>>> common_sums = np.array([7,10,13])
>>> ind_1 = (summed_rows==common_sums[:,None]).any(0).nonzero()[0]   # Option of @Brenlla
>>> ind_2 = np.where(summed_rows == common_sums[:, None])[1]   # Option of @Ravi Sharma
>>> ind_3 = np.arange(summed_rows.shape[0])[np.in1d(summed_rows, common_sums)]
>>> ind_4 = np.where(np.in1d(summed_rows, common_sums))[0]
>>> ind_5 = np.where(np.isin(summed_rows, common_sums))[0]   # Option of @jdehesa

>>> np.array_equal(np.sort(ind_1), np.sort(ind_2))
True
>>> np.array_equal(np.sort(ind_1), np.sort(ind_3))
True
>>> np.array_equal(np.sort(ind_1), np.sort(ind_4))
True
>>> np.array_equal(np.sort(ind_1), np.sort(ind_5))
True

如果你计时,你会发现它们都很相似,但是@Brenlla的选项是最快的

python -m timeit -s 'import numpy as np; np.random.seed(0); a = np.random.randint(low=1, high=14, size=9999); b = np.array([7,10,13])' 'ind_1 = (a==b[:,None]).any(0).nonzero()[0]'
10000 loops, best of 3: 52.7 usec per loop

python -m timeit -s 'import numpy as np; np.random.seed(0); a = np.random.randint(low=1, high=14, size=9999); b = np.array([7,10,13])' 'ind_2 = np.where(a == b[:, None])[1]'
10000 loops, best of 3: 191 usec per loop

python -m timeit -s 'import numpy as np; np.random.seed(0); a = np.random.randint(low=1, high=14, size=9999); b = np.array([7,10,13])' 'ind_3 = np.arange(a.shape[0])[np.in1d(a, b)]'
10000 loops, best of 3: 103 usec per loop

python -m timeit -s 'import numpy as np; np.random.seed(0); a = np.random.randint(low=1, high=14, size=9999); b = np.array([7,10,13])' 'ind_4 = np.where(np.in1d(a, b))[0]'
10000 loops, best of 3: 63 usec per loo

python -m timeit -s 'import numpy as np; np.random.seed(0); a = np.random.randint(low=1, high=14, size=9999); b = np.array([7,10,13])' 'ind_5 = np.where(np.isin(a, b))[0]'
10000 loops, best of 3: 67.1 usec per loop
kd3sttzy

kd3sttzy2#

使用np.isin

import numpy as np
summed_rows = np.random.randint(low=1, high=14, size=9999)
common_sums = np.array([7, 10, 13])
new_array = np.where(np.isin(summed_rows, common_sums))
jdgnovmf

jdgnovmf3#

对于任何查找数组中不相等的数字但最接近的相等值的人来说,这是一个直接的方法,可以对不完全相等的值执行相同的操作。对于巨大的summed_rows,可能会占用大量内存。

import numpy  
    summed_rows = np.random.randint(low=1, high=14, size=9999) 
    common_sums = np.array([7,10,13])
    
    repeat_array = np.repeat(summed_rows, len(common_sums)).reshape(len(summed_rows), len(common_sums)) 
    search_index = np.argmin(np.abs(repeat_array - common_sums), axis=0)

相关问题