计算2个numpy数组之间的最近邻- KDTree

ve7v8dk2  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(87)

我有两个numpy数组:a(较小)数组由int值组成,B(较大)数组由float值组成。这个想法是B包含的浮点值接近于a中的一些int值。例如,作为一个玩具的例子,我有下面的代码。数组不是这样排序的,我在a和b上都使用np.sort()来获得:

a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

对于a中的每个元素,B和中有多个浮点值,目标是为a中的每个元素获得B中最接近的值。
为了简单地实现这一点,我使用了一个for循环:

for e in a:
    idx = np.abs(e - b).argsort()
    print(f"{e} has nearest match = {b[idx[0]]:.4f}")
'''
11 has nearest match = 11.2890
13 has nearest match = 12.8700
20 has nearest match = 20.0500
31 has nearest match = 31.0300
35 has nearest match = 34.9900
48 has nearest match = 48.1000
49 has nearest match = 49.2000
'''
  • a中的值可能不存在于b中,反之亦然。*
    a.size = 2040和B.size = 1041901

构建KD树:

# Construct KD-Tree using and query nearest neighnor-
kd_tree = KDTree(data = np.expand_dims(a, 1))
dist_nn, idx_nn = kd_tree.query(x = np.expand_dims(b, 1), k = [1])

dist.shape, idx.shape
# ((19, 1), (19, 1))

为了得到“b”相对于“a”的最近邻,我这样做:

b[idx]
'''
array([[10.7  ],
       [10.7  ],
       [10.7  ],
       [11.289],
       [11.289],
       [11.289],
       [11.3  ],
       [11.3  ],
       [11.3  ],
       [12.32 ],
       [12.32 ],
       [12.32 ],
       [12.87 ],
       [12.87 ],
       [12.87 ],
       [12.87 ],
       [13.5  ],
       [13.5  ],
       [18.78 ]])
'''

问题:

  • 似乎KD树在'a'中的值不会超过20。[31,25,48,49]在A中完全错过
  • 与for循环的输出相比,它找到的大多数最近邻居都是错误的!!

出什么事了?

hfyxw5xn

hfyxw5xn1#

如果你想为a中的每个条目获取最接近的元素,你可以为b构建KD树,然后查询a

from scipy import spatial

kd = spatial.KDTree(b[:,np.newaxis])
distances, indices = kd.query(a[:, np.newaxis])
values = b[indices]

for ai, bi in zip(a, values):
    print(f"{ai} has nearest match = {bi:.4f}")
35 has nearest match = 34.9900
11 has nearest match = 11.2890
48 has nearest match = 48.1000
20 has nearest match = 20.0500
13 has nearest match = 12.8700
31 has nearest match = 31.0300
49 has nearest match = 49.2000

相关问题