使用条件(numpy.where)时,numpy数组的索引是否更快?

im9ewurl  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(119)

我有一个巨大的numpy数组,其形状为(50000000,3),我使用:

x = array[np.where((array[:,0] == value) | (array[:,1] == value))]

来得到我想要的数组的一部分。但是这种方法似乎很慢。有没有更有效的方法来执行与numpy相同的任务?

brqmpdu1

brqmpdu11#

np.where经过高度优化,我怀疑是否有人能编写出比上一个Numpy版本中实现的代码更快的代码(免责声明:也就是说,这里的主要问题不是np.where,而是创建一个临时布尔数组的条件。不幸的是,这是在Numpy中实现这一点的方法,只要你在相同的输入布局下只使用Numpy,就没有什么可做的。

解释为什么它不是非常有效的一个原因是输入数据布局是低效的。实际上,假设array使用默认行主排序连续存储在内存中,array[:,0] == value将在内存中读取数组的每3个项中的1个项。(即高速缓存行、预取等),浪费了2/3的内存带宽。实际上,输出布尔数组也需要被写入,并且由于页面错误,填充新创建的数组会有点慢。请注意,array[:,1] == value肯定会从RAM重新加载数据(不能适合大多数CPU缓存)。RAM很慢并且与CPU和缓存的计算速度相比越来越慢。这个问题称为“* 内存墙 *",在几十年前就已经观察到了,而且预计不会在短期内得到修复。另外请注意,逻辑“或”也会创建一个从/读取/写入的新数组。更好的数据布局是在内存中连续的(3, 50000000)转置数组(注意np.transpose不会产生连续数组)。
解释性能问题的另一个原因是Numpy往往没有针对在非常小的轴上运行进行优化
一个主要的解决方案是尽可能以转置的方式创建输入。另一个解决方案是编写Numba或Cython代码。下面是非转置输入的实现:

# Compilation for the most frequent types. 
# Please pick the right ones so to speed up the compilation time. 
@nb.njit(['(uint8[:,::1],uint8)', '(int32[:,::1],int32)', '(int64[:,::1],int64)', '(float64[:,::1],float64)'], parallel=True)
def select(array, value):
    n = array.shape[0]
    mask = np.empty(n, dtype=np.bool_)
    for i in nb.prange(n):
        mask[i] = array[i, 0] == value or array[i, 1] == value
    return mask

x = array[select(array, value)]

请注意,我使用了并行实现,因为or运算符在Numba中不是最佳的(唯一的解决方案似乎是使用本机代码或Cython),也因为RAM在一些平台上不能完全饱和于一个线程,如计算服务器。还请注意,对于select的结果,使用array[np.where(select(array, value))[0]]可以更快。实际上,如果结果是随机的或非常小,则np.where可以更快,因为它对布尔索引不执行的这些情况具有特殊优化。注意,np.where在Numba函数的上下文中不是特别优化的,因为Numba使用其自己的Numpy函数的实现,并且它们有时对于大数组没有那么多优化。更快的实现包括并行地创建x,但是这对于Numba来说不是微不足道的,因为输出项的数量事先不知道,并且线程必须知道将数据写入何处,更不用说Numpy在顺序执行时已经相当快了,只要输出是 * 可预测的 *。

相关问题