pytorch 高维矩阵的子矩阵如何在保证子矩阵相对位置的情况下进行合并?

cgyqldqp  于 2023-02-22  发布在  其他
关注(0)|答案(1)|浏览(130)

如果我有一个形状为[z,d,d]的Tensorx,它表示一系列图像帧,就像视频数据一样。令pz = z**0.5,令x = x. view(pz,pz,d,d]。那么我们可以得到一个网格大小为pz * pz的图像网格,每个图像的形状为[d,d]。现在,我想得到一个形状为[1,1,p * d,p * d]的矩阵或Tensor。并且必须保证所有元素与所有原始图像保持相同的插入位置。
例如:

x =    [[[ 0,  1],
             [ 2,  3]],

            [[ 4,  5],
             [ 6,  7]],

            [[ 8,  9],
             [10, 11]],
    
            [[12, 13],
             [14, 15]]]

表示一系列形状为[2,2]且z = 4的图像,我想得到Tensor,如:

tensor([[ 0,  1,  4,  5],
        [ 2,  3,  6,  7],
        [ 8,  9, 12, 13],
        [10, 11, 14, 15]])

我可以使用x = x. view(1,1,4,4)来得到一个具有相同形状的视图,但它如下所示:

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])

但我不想要。
还有更多,如果x有更多的维数呢?就像[b,c,z,d,d]。怎么处理这个?
任何建议都将是有益的。
我有一个关于三维情况的解决方案。如果x. shape =[z,d,d],那么下面的代码将工作。但不工作的高维Tensor。嵌套循环将是好的,但太沉重。我的解决方案为三维情况:

d = 2
    z = 4
    b, c = 1, 1
    x = torch.arange(z*d*d).view(z, d, d)
    # x = torch.tensor([[[ 1,  2],
    #          [ 4,  6]],
    #
    #         [[ 8, 10],
    #          [12, 14]],
    #
    #         [[16, 18],
    #          [20, 22]],
    #
    #         [[24, 26],
    #          [28, 30]],
    #
    #         [[32, 34],
    #          [36, 38]],
    #
    #         [[40, 42],
    #          [44, 46]],
    #
    #         [[48, 50],
    #          [52, 54]],
    #
    #         [[56, 58],
    #          [60, 62]],
    #
    #         [[64, 66],
    #          [68, 70]]])
    # make z-index planes to a grid layout
    grid_side_len = int(z**0.5)
    grid_x = x.view(grid_side_len, grid_side_len, d, d)
    # for all rows of crops , horizontally stack them togather
    plane = []
    for i in range(grid_x.shape[0]):
        cat_crops = torch.hstack([crop for crop in grid_x[i]])
        plane.append(cat_crops)

    plane = torch.vstack([p for p in plane])
    print("3D crop to 2D crop plane:")
    print(x)
    print(plane)
    print(plane.shape)

    print("2D crop plane to 3D crop:")
    # group all rows
    split = torch.chunk(plane, plane.shape[1]//d, dim=0)
    spat_flatten = torch.cat([torch.cat(torch.chunk(p, p.shape[1]//d, dim=1), dim=0) for p in     split], dim=0)
    crops = [t[None,:,:] for t in torch.chunk(spat_flatten, spat_flatten.shape[0]//d, dim=0)]
    spat_crops = torch.cat(crops, dim=0)
    print(spat_crops)
    print(spat_crops.shape)
e4yzc0pl

e4yzc0pl1#

这是一个可以用torch.transposetorch.reshape运算的组合来解决的运算。

>>> x = torch.arange(16).view(4,2,2)

1.首先转置Tensor,使你要校对的维度“垂直”,这可以用x.transpose(dim0=1, dim1=2)来完成。不过,我建议用负维度来代替:

>>> x.transpose(-1,-2)
tensor([[[ 0,  2],
         [ 1,  3]],

        [[ 4,  6],
         [ 5,  7]],

        [[ 8, 10],
         [ 9, 11]],

        [[12, 14],
         [13, 15]]])

1.然后重塑以整理尺寸:

>>> x.transpose(-1,-2).reshape(2,4,2)
tensor([[[ 0,  2],
         [ 1,  3],
         [ 4,  6],
         [ 5,  7]],

        [[ 8, 10],
         [ 9, 11],
         [12, 14],
         [13, 15]]])

1.然后向后翻转以恢复步骤1中的元素顺序。

>>> x.transpose(-1,-2).reshape(2,4,2).transpose(-1,-2)
tensor([[[ 0,  1,  4,  5],
         [ 2,  3,  6,  7]],

        [[ 8,  9, 12, 13],
         [10, 11, 14, 15]]])

1.最后,将其重塑为所需的形状:

>>> x.transpose(-1,-2).reshape(2,4,2).transpose(-1,-2).reshape(len(x),-1)
tensor([[ 0,  1,  4,  5],
        [ 2,  3,  6,  7],
        [ 8,  9, 12, 13],
        [10, 11, 14, 15]])

从那里你可以通过改变维度大小甚至扩展到更高的维度来满足你的需要,比如你所描述的[b, c, z, d, d]。如果你通过这个例子理解了这个简单的方法,你就可以解决任何类似的问题。

相关问题