numpy 不需要额外内存分配的乘积和和

cqoc49vn  于 2023-04-12  发布在  其他
关注(0)|答案(1)|浏览(162)

有没有办法在不分配额外内存的情况下将两个数组相乘并沿着一个轴(或多个轴)求和?
在本例中:

import numpy as np

A = np.random.random((10, 10, 10))
B = np.random.random((10, 10, 10))

C = np.sum(A[:, None, :, :, None] * B[None, :, None, :, :], axis=(-1,-2))

在计算C时,创建了一个大小为10x10x10x10x10的中间矩阵,但它会立即被折叠。在numpy中有没有方法可以避免这种情况?

g6ll5ycj

g6ll5ycj1#

这看起来像是第二个数组转置的点积:

C = A @ B.T
  • 注:原质询的行动是:C = np.sum(A[:, None, :] * B[None, :, :], axis=-1) .*

快速检查:

C1 = np.sum(A[:, None, :] * B[None, :, :], axis=-1)
C2 = A @ B.T

assert np.allclose(C1, C2)

你可以用einsum来概括:

np.einsum('ikl,jlm->ijk', A, B)

快速检查:

A = np.random.random((2, 3, 4))
B = np.random.random((2, 4, 5))

#             i    j   k  l    m         i   j    k   l  m
C1 = np.sum(A[:, None, :, :, None] * B[None, :, None, :, :], axis=(-1,-2))
C2 = np.einsum('ikl,jlm->ijk', A, B)

assert np.allclose(C1, C2)

相关问题