python-3.x 将NumPy中的两个子矩阵相乘而不复制矩阵

byqmnocz  于 2022-12-20  发布在  Python
关注(0)|答案(1)|浏览(98)

我有两个矩阵AB,我想把A的一个子矩阵(由它的一些行定义)和B的一个子矩阵(由它的一些列定义)相乘,如下例所示:

import numpy as np

# create two matrices
A = np.ones((5, 3))
B = np.ones((3, 5))

# define sub-matrices
rows_idx = [0, 2, 3]
cols_idx = [1, 2, 4]

# print sub-matrices
print(A[rows_idx])
print(B[:, cols_idx])

我知道这可以直接通过

A[rows_idx, :] @ B[:, cols_idx]

然而,后者的缺点是复制A的行和B的列,因为rows_idxcols_idx是列表,如这里所提到的,这在时间和存储器方面效率较低。
此外,我没有创建Python函数来执行乘法,因为Python循环速度很慢。Numpy数组操作要快得多。Here是一些计时比较。
有没有不需要复制就可以计算A[rows_idx, :] @ B[:, cols_idx]的方法?

bvjveswy

bvjveswy1#

从技术上讲,您可以通过为每个具有相同数据的原始矩阵声明一个视图来避免复制子矩阵。
这些视图可用于进一步计算。
但是,这需要您至少遵守2个时刻:
1.安排/同意dtype:原文及其观点
1.请注意,对视图所做的任何变更都会反映在原始复本中
ndarray * 的base属性使区分数组是视图还是副本变得容易。视图的base属性返回原始数组,而对于副本则返回None。如何区分数组是视图还是副本
下面是它可能的样子:

# create two matrices
a = np.ones((5, 3), dtype=np.int8)
b = np.ones((3, 5), dtype=np.int8)

# declare views
a_view = a.view(dtype=np.int8)
b_view = b.view(dtype=np.int8)

# indices for sub-views
rows_idx = [0, 2, 3]
cols_idx = [1, 2, 4]

mult_res = a_view[rows_idx, :] @ b_view[:, cols_idx]
print(mult_res)
[[3 3 3]
 [3 3 3]
 [3 3 3]]

相关问题