我有一个二维Tensor和一个一维Tensor:
import torch
torch.manual_seed(0)
out = torch.randn((16,2))
target = torch.tensor([0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0])
对于out
的每一行,我想选择由target
索引的相应列。因此,我的输出将是(16,1)
Tensor。我尝试了下面提到的解决方案:
https://stackoverflow.com/a/58937071
但我得到:
Traceback (most recent call last):
File "/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-7-50d103c3b56c>", line 1, in <cell line: 1>
out.gather(1, target)
RuntimeError: Index tensor must have the same number of dimensions as input tensor
你能帮忙吗?
1条答案
按热度按时间72qzrwbm1#
为了应用
torch.gather
,这两个Tensor必须具有相同的维数。因此,您应该在target
* 的最后一个位置 * 上解压缩一个额外的维数: