有没有办法在不分配额外内存的情况下将两个数组相乘并沿着一个轴(或多个轴)求和?
在本例中:
import numpy as np
A = np.random.random((10, 10, 10))
B = np.random.random((10, 10, 10))
C = np.sum(A[:, None, :, :, None] * B[None, :, None, :, :], axis=(-1,-2))
在计算C时,创建了一个大小为10x10x10x10x10的中间矩阵,但它会立即被折叠。在numpy中有没有方法可以避免这种情况?
1条答案
按热度按时间g6ll5ycj1#
这看起来像是第二个数组转置的点积:
C = np.sum(A[:, None, :] * B[None, :, :], axis=-1)
.*快速检查:
你可以用
einsum
来概括:快速检查: