pytorch Python中无for循环的矩阵行相邻元素局部极小快速求法

qgelzfjb  于 2023-10-20  发布在  Python
关注(0)|答案(1)|浏览(112)

假设我在python中有一个行向量,对于矩阵的一行中的每一个条目,我想在上面一行的三个邻域内找到上面一行中的最小值。
说得再清楚一点,我想找到这个:
M[i,j]=M[i,j]+min(M[i-1][j-1:j+2]
很明显,这可以使用两个嵌套的for循环来完成,一个迭代矩阵的行,另一个迭代矩阵的列。然而,这是非常非常耗时的。
我正在寻找一种方法来做到这一点,只有循环,以优化我的代码。我可以使用pytorch内置函数来加速这个过程吗?
正如我所说,我使用2 for循环实现了这一点,但它非常耗时,我正在寻找一个基于向量的解决方案,而不是使用for循环。

juzqafwq

juzqafwq1#

假设你的原始代码是这样的:

import numpy as np

M = np.random.random((20, 30))

new_M = M.copy()
for i in range(1, M.shape[0]): # First row unaffected
    for j in range(M.shape[1]):
        new_M[i, j] = M[i, j] + min(M[i - 1][max(0, j - 1) : j + 2])

一个矢量化的等价物是:

from numpy.lib.stride_tricks import sliding_window_view

new_M = M
new_M[1:] = (
    new_M[1:]
    + sliding_window_view(
        np.pad(M, ((0, 0), (1, 1)), "constant", constant_values=np.inf)[:-1, :], (1, 3)
    )
    .min(-1)
    .squeeze()
)

让我们把它分解。我们在这里做的是:

  • np.pad(M, ((0, 0), (1, 1)), "constant", constant_values=np.inf)中的np.inf填充M的左侧和右侧,因此min操作仍然在边缘上工作。
  • 使用[:-1, :]删除最后一行,这在公式中没有使用。
  • sliding_window_view(M_pad[:-1, :], (1, 3))创建一个形状为(*M_pad.shape, 1, 3)的数组,其中最后一个维度中的3个元素是用于每个min的3个元素。

最后,我们取最小值,压缩结果,使其与M的形状相同,但最后一行缺失,并将其添加到M[1:],因为第一行不受影响。

相关问题