使非对称Tensor的对角线完全位于pytorch的右下角

ecbunoof  于 2023-03-12  发布在  其他
关注(0)|答案(1)|浏览(119)

我有一个形状为(5 * n, n)的Tensor,我基本上想从第一列的前5行中提取前5个元素,然后移动1列,提取第二列的下5行,然后移动等。这有点像 Torch 。对角线,但适用于非对称Tensor,可以假设Tensor有适当的维数,每次都能工作。
例如,如果我的输入Tensor是:

>>> t = torch.arange(45).reshape(15, 3)
>>> t
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35],
        [36, 37, 38],
        [39, 40, 41],
        [42, 43, 44]])

我就会想办法

out = tensor([0, 3, 6, 9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])

我不希望使用循环,因为对于一些很大的输入,我会重复上千次,所以我觉得这样效率会很低。

djmepvbi

djmepvbi1#

我得到了一个使用切片的高效解决方案:

import torch
from itertools import tee

def pairwise(iterable):
    # pairwise('ABCDEFG') --> AB BC CD DE EF FG
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

def generalized_diagonal(t):
    ratio = int(max(t.shape) / min(t.shape))
    indexes = ( (i, i*ratio) for i in range(min(t.shape)+1) )
    parts = [t[y0:y1, x0:x1] for (x0, y0), (x1, y1) in pairwise(indexes)]
    return torch.flatten(torch.stack(parts, dim=0))

t = torch.arange(45).reshape(15, 3)
print(generalized_diagonal(t))

输出:

tensor([ 0,  3,  6,  9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])

相关问题