我尝试使用一个jitted numba函数将大小为$(N,N)$的矩阵A乘以大小为$(N,M,O)$的矩阵B(也就是将B在O维的所有"页"左乘A)。
我想出了这个解决方案:
@njit
def fast_expectation(Pi, X):
res = np.empty_like(X)
for i in range(Pi.shape[0]):
for j in range(X.shape[1]):
for k in range(X.shape[2]):
res[i,j,k] = np.dot(Pi[i,:], X[:,j,k])
return res
然而,这会返回一个警告NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 1d, C), array(float64, 1d, A))
。你知道我如何用一个numba兼容的函数快速地执行这个操作吗?
我试着运行前面的代码,交换矩阵B的数组(把它变成一个(N,M,O)矩阵)。
编辑:
我还尝试了以下代码:
@njit
def multiply_ith_dimension(Pi, i, X):
"""If Pi is a matrix, multiply Pi times the ith dimension of X and return"""
X = np.swapaxes(X, 0, i)
shape = X.shape
X = X.reshape(shape[0], -1)
# iterate forward using Pi
X = Pi @ X
# reverse steps
X = X.reshape(Pi.shape[0], *shape[1:])
return np.swapaxes(X, 0, i)
这也给了我一个错误
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
reshape() supports contiguous array only
...
<source elided>
shape = X.shape
X = X.reshape(shape[0], -1)
^
2条答案
按热度按时间7gcisfzg1#
但是,这将返回警告
NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 1d, C), array(float64, 1d, A))
这是因为您使用
X[:,j,k]
访问了row-majorNumpy数组。最后一个维是连续的,但第一个维不是。您可以使用换位来修复这个问题。也就是说,Numpy换位不会创建新数组。相反,它会创建一个具有跨距的view。您可以使用**np.ascontiguousarray
**强制创建连续数组。或者,您可以直接显式复制数组。例如:arr.T.copy()
.另一种解决方案是不使用
np.dot
,而是使用普通循环。循环在Numba中通常非常有效,而调用Numpy函数会引入一些开销(主要是由于分配、临时数组或隐式转换)。也会产生一个错误
[...] reshape() supports contiguous array only
这是Numba目前的一个限制。Numpy可以正确地做到这一点(至少我在我的机器上没有Numpy 1.22.4的问题)。这当然是因为Numba希望结果总是一个视图,而Numpy实际上可以在某些情况下返回一个副本(通常像这样)。
请注意,您的第一个Numba代码效率很低,因为它是顺序的,而且很幼稚(没有平铺,
A @ B
通常非常高效,因为它使用了BLAS库。(大多数平台上的默认实现)或"英特尔MKL"库几十年来一直由该领域的Maven进行高度优化。如果可能,它们使用多线程,并使用复杂得多优化代码(通常在C中手动使用SIMD指令)。camsedfj2#
使用
np.einsum
(或者更快的opt_einsum
库)可以很容易地做到这一点。result = np.einsum('ab,bcd->acd', A, B, optimize=True)
注意,
einsum
在幕后使用np.dot
(基于BLAS),因此可能是最快的方法。