# Compilation for the most frequent types.
# Please pick the right ones so to speed up the compilation time.
@nb.njit(['(uint8[:,::1],uint8)', '(int32[:,::1],int32)', '(int64[:,::1],int64)', '(float64[:,::1],float64)'], parallel=True)
def select(array, value):
n = array.shape[0]
mask = np.empty(n, dtype=np.bool_)
for i in nb.prange(n):
mask[i] = array[i, 0] == value or array[i, 1] == value
return mask
x = array[select(array, value)]
1条答案
按热度按时间brqmpdu11#
np.where
经过高度优化,我怀疑是否有人能编写出比上一个Numpy版本中实现的代码更快的代码(免责声明:也就是说,这里的主要问题不是np.where
,而是创建一个临时布尔数组的条件。不幸的是,这是在Numpy中实现这一点的方法,只要你在相同的输入布局下只使用Numpy,就没有什么可做的。解释为什么它不是非常有效的一个原因是输入数据布局是低效的。实际上,假设
array
使用默认行主排序连续存储在内存中,array[:,0] == value
将在内存中读取数组的每3个项中的1个项。(即高速缓存行、预取等),浪费了2/3的内存带宽。实际上,输出布尔数组也需要被写入,并且由于页面错误,填充新创建的数组会有点慢。请注意,array[:,1] == value
肯定会从RAM重新加载数据(不能适合大多数CPU缓存)。RAM很慢并且与CPU和缓存的计算速度相比越来越慢。这个问题称为“* 内存墙 *",在几十年前就已经观察到了,而且预计不会在短期内得到修复。另外请注意,逻辑“或”也会创建一个从/读取/写入的新数组。更好的数据布局是在内存中连续的(3, 50000000)
转置数组(注意np.transpose
不会产生连续数组)。解释性能问题的另一个原因是Numpy往往没有针对在非常小的轴上运行进行优化。
一个主要的解决方案是尽可能以转置的方式创建输入。另一个解决方案是编写Numba或Cython代码。下面是非转置输入的实现:
请注意,我使用了并行实现,因为
or
运算符在Numba中不是最佳的(唯一的解决方案似乎是使用本机代码或Cython),也因为RAM在一些平台上不能完全饱和于一个线程,如计算服务器。还请注意,对于select
的结果,使用array[np.where(select(array, value))[0]]
可以更快。实际上,如果结果是随机的或非常小,则np.where
可以更快,因为它对布尔索引不执行的这些情况具有特殊优化。注意,np.where
在Numba函数的上下文中不是特别优化的,因为Numba使用其自己的Numpy函数的实现,并且它们有时对于大数组没有那么多优化。更快的实现包括并行地创建x
,但是这对于Numba来说不是微不足道的,因为输出项的数量事先不知道,并且线程必须知道将数据写入何处,更不用说Numpy在顺序执行时已经相当快了,只要输出是 * 可预测的 *。