我有两个numpy数组:a(较小)数组由int值组成,B(较大)数组由float值组成。这个想法是B包含的浮点值接近于a中的一些int值。例如,作为一个玩具的例子,我有下面的代码。数组不是这样排序的,我在a和b上都使用np.sort()来获得:
a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])
对于a中的每个元素,B和中有多个浮点值,目标是为a中的每个元素获得B中最接近的值。
为了简单地实现这一点,我使用了一个for循环:
for e in a:
idx = np.abs(e - b).argsort()
print(f"{e} has nearest match = {b[idx[0]]:.4f}")
'''
11 has nearest match = 11.2890
13 has nearest match = 12.8700
20 has nearest match = 20.0500
31 has nearest match = 31.0300
35 has nearest match = 34.9900
48 has nearest match = 48.1000
49 has nearest match = 49.2000
'''
- a中的值可能不存在于b中,反之亦然。*
a.size = 2040和B.size = 1041901
构建KD树:
# Construct KD-Tree using and query nearest neighnor-
kd_tree = KDTree(data = np.expand_dims(a, 1))
dist_nn, idx_nn = kd_tree.query(x = np.expand_dims(b, 1), k = [1])
dist.shape, idx.shape
# ((19, 1), (19, 1))
为了得到“b”相对于“a”的最近邻,我这样做:
b[idx]
'''
array([[10.7 ],
[10.7 ],
[10.7 ],
[11.289],
[11.289],
[11.289],
[11.3 ],
[11.3 ],
[11.3 ],
[12.32 ],
[12.32 ],
[12.32 ],
[12.87 ],
[12.87 ],
[12.87 ],
[12.87 ],
[13.5 ],
[13.5 ],
[18.78 ]])
'''
问题:
- 似乎KD树在'a'中的值不会超过20。[31,25,48,49]在A中完全错过
- 与for循环的输出相比,它找到的大多数最近邻居都是错误的!!
出什么事了?
1条答案
按热度按时间hfyxw5xn1#
如果你想为
a
中的每个条目获取最接近的元素,你可以为b
构建KD树,然后查询a
。