python 了解PytorchTensor切片

7uzetpgm  于 2023-02-28  发布在  Python
关注(0)|答案(2)|浏览(181)

ab是两个PyTorchTensor,分别为a.shape=[A,3]b.shape=[B,3],且b的类型为long
然后我知道有几种方法可以对a进行切片。

c = a[N1:N2:jump,[0,2]] # N1<N2<A

对于N1=1、N2=4和jump=2,将返回c.shape = [2,2]
但是下面的代码应该会抛出一个错误,

c = a[b]

而是c.shape = [B,3,3]
例如,

a = torch.rand(10,3)
b = torch.rand(20,3).long()
print(a[b].shape) #torch.Size([20, 3, 3])

有人能解释一下a[b]的切片工作原理吗?

vlju58qv

vlju58qv1#

基础知识

  • 当您使用a[B]时,PyTorch执行高级索引。
  • 在这种情况下,TensorB的每一行都被视为a的第一维的索引,并且返回a的对应行。
  • 由于B的形状为[B,3],这意味着b的每一行都是a的第一维中的3元素索引。因此a[b]的结果将具有形状[B,3,d],其中d是a中的列数。

例如

假设B具有以下值:

b = torch.tensor([[0,1,2], [3,4,5], [1,2,3]])
  • 那么a[B]的结果将是形状为[3,3,3]的Tensor,其中第一维对应于b的三行,第二维对应于b的每行中的三个索引。

以下是这些值的计算方法:

  • B的第一行是[0,1,2]。
  • 这意味着返回a的第一行,
  • 接着是A的第二行,然后是A的第三行。
  • 因此,结果的第一个“切片”将是:
[[a[0,0], a[0,1], a[0,2]],
 [a[1,0], a[1,1], a[1,2]],
 [a[2,0], a[2,1], a[2,2]]]

B的第二行是[3,4,5]。

  • 这意味着返回a的第四行,
  • 接着是A的第五行,
  • 然后是A的第六行。
  • 因此,结果的第二个“切片”将是:
[[a[3,0], a[3,1], a[3,2]],
 [a[4,0], a[4,1], a[4,2]],
 [a[5,0], a[5,1], a[5,2]]]

B的第三行是[1,2,3]。

  • 这意味着返回a的第二行,
  • 接着是A的第三行,
  • 然后是A的第四行。
  • 因此,结果的第三个“切片”将是:
[[a[1,0], a[1,1], a[1,2]],
 [a[2,0], a[2,1], a[2,2]],
 [a[3,0], a[3,1], a[3,2]]]

所有这些切片沿着第一维度连接以产生具有形状[3,3,3]的最终结果。

cxfofazt

cxfofazt2#

由于B是长的,torch把它当作指数仓位,如果它不是长的,上面的操作就不起作用。

In [29]: a[b]
Out[29]: 
tensor([[[-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[ 0.3707, -0.6549,  1.3003],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.7021, -1.1604, -0.8919],
         [ 0.3707, -0.6549,  1.3003],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.9325,  1.2281,  1.0513],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919]],

        [[-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-1.9443, -1.5545,  0.3944],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]]])

In [30]: a
Out[30]: 
tensor([[-0.4933,  0.8588,  1.5655],
        [-1.9443, -1.5545,  0.3944],
        [ 0.3707, -0.6549,  1.3003],
        [ 0.6938, -1.1753, -0.0484],
        [-0.0178, -0.0227,  0.3007],
        [-1.7586, -0.6923,  3.0981],
        [ 1.0726,  0.3889,  1.6468],
        [ 1.7248, -2.6932, -1.2202],
        [-0.9325,  1.2281,  1.0513],
        [-0.7021, -1.1604, -0.8919]])

In [31]: b
Out[31]: 
tensor([[ 0, -1,  0],
        [ 1, -1,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 1,  0,  0],
        [ 2,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [-1,  2,  0],
        [-1,  0,  0],
        [ 0,  0,  0],
        [ 0, -1,  0],
        [ 1, -2,  0],
        [ 0,  0,  0],
        [ 0,  0, -1],
        [ 0,  0, -1],
        [-1,  0,  0],
        [ 0,  1,  0],
        [ 0,  0,  0],
        [ 1, -1,  0]])

注意,a[b]的第一个元素是a的第一个元素,并且最后一个元素和再次的第一个元素对应于索引[0, -1, 0],并且因此由于它对于a的相关位置的每个条目进行采样,所以你得到[20, 3, 3]形状。
因此,假设b中的每个条目对应于a焊炬切片a中具有给定位置的有效索引,并且对于b的每个条目也是如此,并将所有条目连接到具有上述形状的新Tensor。如果将存在无效索引(b = torch.randn(20, 3).long() * 10),则您将得到:

----> 1 a[b]

IndexError: index 10 is out of bounds for dimension 0 with size 10

相关问题