pytorch 从2DTensor中,通过每行选择1列返回1DTensor

ivqmmu1c  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(289)

我有一个二维Tensor和一个一维Tensor:

import torch
torch.manual_seed(0)

out = torch.randn((16,2))
target = torch.tensor([0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0])

对于out的每一行,我想选择由target索引的相应列。因此,我的输出将是(16,1)Tensor。我尝试了下面提到的解决方案:
https://stackoverflow.com/a/58937071
但我得到:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-50d103c3b56c>", line 1, in <cell line: 1>
    out.gather(1, target)
RuntimeError: Index tensor must have the same number of dimensions as input tensor

你能帮忙吗?

72qzrwbm

72qzrwbm1#

为了应用torch.gather,这两个Tensor必须具有相同的维数。因此,您应该在target * 的最后一个位置 * 上解压缩一个额外的维数:

>>> out.gather(1, target[:,None])
tensor([[-1.1258],
        [-0.4339],
        [ 0.6920],
        [-2.1152],
        [ 0.3223],
        [ 0.3500],
        [ 1.2377],
        [ 1.1168],
        [-1.6959],
        [ 0.7935],
        [ 0.5988],
        [-0.3414],
        [ 0.7502],
        [ 0.1835],
        [ 1.5863],
        [ 0.9463]])

相关问题