过滤掉NumPy中接近的值

pbpqsu0x  于 2022-11-10  发布在  其他
关注(0)|答案(1)|浏览(246)

对于一个小项目,我需要在一个数值数组中筛选出距离很近的值(距离由一个参数给出)。NumPy数组中的值始终是排序的,并且每个值都是唯一的。
让我们来看一个例子:

sample_array = np.array([1, 7, 15, 16, 18, 19, 26, 33])

对于给定的距离,我想删除与前一个值太接近的所有值。所以

dist = 1

# Desired result (remove 16 and 19)

result = np.array([1, 7, 15, 18, 26, 33])

dist = 3

# Desired result (remove 16 and 18)

result = np.array([1, 7, 15, 19, 26, 33])

现在,我已经在数组上实现了这个循环。因为我还有一些更大的数组,所以我想知道是否有更有效的解决方案。
编辑:仅供参考,以下是我目前的实现:

it = np.nditer(sample_array[1:])
result_list = [sample_array[0]]

for i in it:
    if (i - result_list [-1]) > delta:
        result_list .append(i)

result = np.array(result_list)

朱尔茨

ghhkc1vu

ghhkc1vu1#

这一点的任何实现都将涉及一个循环,现在在Python语言中,该循环将如下所示:

import numpy as np

def rem_dist(arr,dist):
    new = [arr[0]]
    for n in arr[1:]:
        if n - new[-1] > dist:
            new.append(n)
    return np.array(new)

sample_array = np.array([1, 7, 15, 16, 18, 19, 26, 33])
print(rem_dist(sample_array, 1))
print(rem_dist(sample_array, 3))
[ 1  7 15 18 26 33]
[ 1  7 15 19 26 33]

一种明显的加速方法是用Numba或Cython将其编译成机器代码,从而获得几乎快100倍的代码。

import numba

@numba.njit
def rem_dist2(arr, dist):
    elements = arr
    results = np.empty(len(arr), dtype=numba.boolean)
    results[0] = True
    i = 0
    j = 1
    while j < len(elements):
        if (elements[j] - elements[i]) <= dist:
            results[j] = False
        else:
            results[j] = True
            i = j
        j += 1

    return_val = arr[results]
    return return_val

对于10,000个元素

time rem_dist per call = 0.0028135049999999996
time rem_dist2 per call = 2.32650000000012e-05

相关问题