python 将(N,N)矩阵沿着O维乘以(N,M,O)矩阵,计算公式为Numba

sqougxex  于 2023-01-19  发布在  Python
关注(0)|答案(2)|浏览(408)

我尝试使用一个jitted numba函数将大小为$(N,N)$的矩阵A乘以大小为$(N,M,O)$的矩阵B(也就是将B在O维的所有"页"左乘A)。
我想出了这个解决方案:

@njit
def fast_expectation(Pi, X):
    
    res = np.empty_like(X)
    
    for i in range(Pi.shape[0]):
        for j in range(X.shape[1]):
            for k in range(X.shape[2]):
                res[i,j,k] = np.dot(Pi[i,:], X[:,j,k])
                            
    return res

然而,这会返回一个警告NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 1d, C), array(float64, 1d, A))。你知道我如何用一个numba兼容的函数快速地执行这个操作吗?
我试着运行前面的代码,交换矩阵B的数组(把它变成一个(N,M,O)矩阵)。
编辑:
我还尝试了以下代码:

@njit
def multiply_ith_dimension(Pi, i, X):
    """If Pi is a matrix, multiply Pi times the ith dimension of X and return"""
    X = np.swapaxes(X, 0, i)
    shape = X.shape
    X = X.reshape(shape[0], -1)

    # iterate forward using Pi
    X = Pi @ X

    # reverse steps
    X = X.reshape(Pi.shape[0], *shape[1:])
    return np.swapaxes(X, 0, i)

这也给了我一个错误

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
reshape() supports contiguous array only
...
    <source elided>
    shape = X.shape
    X = X.reshape(shape[0], -1)
    ^
7gcisfzg

7gcisfzg1#

但是,这将返回警告NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 1d, C), array(float64, 1d, A))
这是因为您使用X[:,j,k]访问了row-majorNumpy数组。最后一个维是连续的,但第一个维不是。您可以使用换位来修复这个问题。也就是说,Numpy换位不会创建新数组。相反,它会创建一个具有跨距的view。您可以使用**np.ascontiguousarray**强制创建连续数组。或者,您可以直接显式复制数组。例如:arr.T.copy().
另一种解决方案是不使用np.dot,而是使用普通循环。循环在Numba中通常非常有效,而调用Numpy函数会引入一些开销(主要是由于分配、临时数组或隐式转换)。
也会产生一个错误[...] reshape() supports contiguous array only
这是Numba目前的一个限制。Numpy可以正确地做到这一点(至少我在我的机器上没有Numpy 1.22.4的问题)。这当然是因为Numba希望结果总是一个视图,而Numpy实际上可以在某些情况下返回一个副本(通常像这样)。
请注意,您的第一个Numba代码效率很低,因为它是顺序的,而且很幼稚(没有平铺,A @ B通常非常高效,因为它使用了BLAS库。(大多数平台上的默认实现)或"英特尔MKL"库几十年来一直由该领域的Maven进行高度优化。如果可能,它们使用多线程,并使用复杂得多优化代码(通常在C中手动使用SIMD指令)。

camsedfj

camsedfj2#

使用np.einsum(或者更快的opt_einsum库)可以很容易地做到这一点。
result = np.einsum('ab,bcd->acd', A, B, optimize=True)
注意,einsum在幕后使用np.dot(基于BLAS),因此可能是最快的方法。

相关问题