pandas 对于1D numpy数组中的每个元素,从第二个数组中找到最低索引元素,使得值和索引都大于第一个

hgncfbus  于 2023-10-14  发布在  其他
关注(0)|答案(2)|浏览(118)

我有两个长度约为100万的1D numpy数组。对于第一个数组中的每个元素x,我希望第二个数组中的最低索引元素y,使得y > x且arg(y)> arg(x),即索引y >索引x。
假设
arr1 = [3,2,1,0,3,2,3] arr2 = [0,2,1,2,3,4,6,5]
输出为-
值= [4,3,2,3,4,6,5]索引= [5,4,3,4,5,6,7]
完全迷路了。无法使用广播,因为数组长度为100万。我已经尝试广播以获得元素,使得y > x,但我不知道如何应用第二个条件。

k5ifujac

k5ifujac1#

  • 对于相对较小的数组:*

可以将arr2广播到arr1,以通过值(arr2 > arr1[:, None])和索引(np.indices(arr2.shape)[0] > np.indices(arr1.shape)[0][:, None])执行>(* 大于 *)比较。
这涉及到在arr1上放置一个新的轴,使其成为一个列向量,并最终创建2个布尔掩码。
这些掩码的交集允许找到具有np.argmax的第一个有效索引,然后将其用于采用具有np.take的值。
整个解决方案如下:

indices = (np.argmax((arr2 > arr1[:, None]) 
           & (np.indices(arr2.shape)[0] > np.indices(arr1.shape)[0][:, None]), axis=1))
values = np.take(arr2, indices)
indices
array([5, 4, 3, 4, 5, 6, 7])

values
array([4, 3, 2, 3, 4, 6, 5])
  • 对于第百万个阵列:*

我建议使用numbanjit + parallel来实现高性能计算:

from numba import njit, prange

@njit(parallel=True)
def arrange_arrays(a1, a2):
    indices = np.empty_like(a1)
    values = np.empty_like(a1)

    for i in prange(len(a1)):
        for j in prange(i + 1, len(a2)):
            if j > i and a2[j] > a1[i]:
                values[i] = a2[j]
                indices[i] = j
                break

    return (values, indices)

arr1 = np.random.randint(0, 1000_000, 1000_1000)
arr2 = np.random.randint(0, 1000_000, 1000_1000)

values, indices = arrange_arrays(arr1, arr2)
print(values[:30], indices[:30])
[428072 628854 791813 796369 924274 901268 732129 949403 949403 800018
 985390 750386 602616 791772 870306 870306 857630 759459 335358 881248
 881248 985390 985390 773380 940193 987427 987427 987427 386893] [ 1  2  3  4  5  6  7  9  9 10 23 12 13 14 16 16 17 18 20 21 21 23 23 24
 25 31 31 31 29 31]
pkmbmrz7

pkmbmrz72#

在遍历arr 1时,可以对arr 2进行切片,以便切片仅包含超过arr 1中当前值的索引的索引。然后,您可以沿着arr 2的切片,直到找到一个超过arr 1的当前值的值。在这种情况下,您只考虑arr 2中满足这两个要求的项;你执行最小数量的比较;你的平均最大比较次数(在arr 1中的值上)是arr 1的长度除以2;总的最大比较次数是n*(n-1)/2,其中n是arr 1的长度;并且预期的比较次数是最大值的一半。

arr1 = [3,2,1,0,3,2,3] 
arr2 = [0,2,1,2,3,4,6,5]
arr_val = []
arr_ind = []
for xindx, x in enumerate(arr1):
    for yindx, y in enumerate(arr2[xindx+1:]):
        if (y>x):
            arr_val.append(y)
            arr_ind.append(yindx+xindx+1)
            break
# [4, 3, 2, 3, 4, 6, 5], [5, 4, 3, 4, 5, 6, 7]

这段代码将产生您需要的结果
[四、三、二、三、四、六、五] [五、四、三、四、五、六、七]
另一种方法使用.where()

import numpy as np
np_arr1 = np.array([3,2,1,0,3,2,3])
np_arr2 = np.array([0,2,1,2,3,4,6,5])
np_arr_val = np.array([])
np_arr_ind = np.array([])
for xindx, x in enumerate(np_arr1):
    yloc = np.where(np_arr2[xindx+1:] > x)[0][0] + xindx+1
    np_arr_val = np.append(np_arr_val, np_arr2[yloc])
    np_arr_ind = np.append(np_arr_ind, yloc)
print(np_arr_ind)
print(np_arr_val)

这产生相同的结果。

相关问题