Numpy:按2D遮罩过滤2D索引列表

yjghlzjz  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(134)

我有一个2D点data(x,y,other_properties)的numpy数组,其中x和y是整数像素坐标。另外,我有一个二进制2D分割掩码mask。我想过滤列表,只获得掩码为1/true的点。
我想做的事情是这样的:

valid_indices = np.argwhere(mask_2D)

然后根据有效的索引过滤数据,我想使用numpy加速来完成。

# the data representing x and y coordinates
data = np.arange(10).reshape((5, 2))
data = np.concatenate((data, data), axis=0)
print(f"data: {data}")

# Let's say we already obtained the indices from the segmentation mask
valid_indices = [(6, 7), (0, 1)]
print(f"valid_indices: {valid_indices}")

filtered = []
for point in data:
    if tuple(point) in valid_indices:
        filtered.append(tuple(point))
filtered = np.array(filtered)
print(f"filtered: {filtered}")

输出:

data:
[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]
 [0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]
valid_indices:
[(6, 7), (0, 1)]
filtered:
[[0 1]
 [6 7]
 [0 1]
 [6 7]]

Process finished with exit code 0


有没有方法可以使用numpy获得上述行为?解决方案也可以直接使用二进制2D分割掩码。如果没有,您有什么建议如何加快这一过程?谢谢!

vof42yt1

vof42yt11#

你可以用广播比较来做,

data[(data == np.array(valid_indices)[:,None]).all(-1).any(0)]

相关问题