pytorch 如何用3D数组索引2D数组?

aydmsdu9  于 2022-12-13  发布在  其他
关注(0)|答案(1)|浏览(209)

今天,我遇到了这样一个问题:
TensorA是形状为(1, 4, 4)的分割掩模,其值为0或1。
TensorB是由torch.eye(2)创建的对角数组。
我的问题是,为什么我们可以用B[A]的形式的A(3D)来索引B(2D),为什么结果是(1, 4, 4, 2)形状的Tensor?
上面是我的测试示例,socure代码是从一个dicloss类中获得的:

y_true_dummy = torch.eye(num_classes)[y_true.squeeze(1)]

y_true的形状是(b, h, w)num_classes等于c
顺便问一下,为什么我们需要函数.squeeze()
我想一些关于索引问题的解释和一些视频更赞赏。

afdcj2ne

afdcj2ne1#

如果您处理一个较小的示例,就可以理解这个问题:

A = torch.randint(2, (4,))
B = torch.eye(2)

>>> A
# tensor([1, 0, 1, 1])

>>> B[A].shape
# (4, 2)

>>> B[A]
# tensor([[0., 1.],
#         [1., 0.],
#         [0., 1.],
#         [0., 1.]])

[1, 0][0, 1]是2x2单位矩阵B的第一行和第二行。因此,使用形状为(4,)的一维数组A作为索引是选择B的4个“行”/沿着轴0选择B的4个元素。B[A]基本上是[B[1], B[1], B[0], B[1]]
因此,当A是形状为(1, 4, 4)的三维数组时,B[A]意味着**选择B的(1,4,4)行。**由于B中的每一行都有2个元素(2列),因此输出为(1,4,4,2)。
B是一个2x2单位矩阵,有2行。从这两行中选出16行,得到一个(16,2)矩阵,然后将其整形得到(1,4,4,2)Tensor。实际上,你可以很容易地检查:

A = torch.randint(2, (4, 4))
A_flat = A.reshape(-1)
B = torch.eye(2)

>>> torch.allclose(B[A], B[A_flat].reshape(1, 4, 4, -1)])
# True

这也不是PyTorch特有的现象,你可以在NumPy中观察到同样的索引规则,它与torch保持着紧密的兼容性。

相关问题