选择3D PyTorchTensor的特定行

tjjdgumg  于 2023-11-19  发布在  其他
关注(0)|答案(2)|浏览(87)

我想通过给出源Tensor和索引Tensor来选择2dTensor的特定行。
输入:
a = torch.FloatTensor([[1,1,1],[2,2,2]],[[9,9,9],[5,5,5]])B = torch.IntTensor([1,0])
有解决办法吗?
预期结果:2,2,2,[9,9,9]]

ltqd579y

ltqd579y1#

我建议使用聚集方法和2D索引

a = torch.FloatTensor([[[1,1,1],[2,2,2]],[[9,9,9],[5,5,5]]])
b = torch.LongTensor([1,0])

R = a.shape[0]
C = a.shape[2]

idx = b.unsqueeze(dim=1).repeat(1, C).view(R, 1, C)
torch.gather(a, 1, idx)

字符串

pjngdqdw

pjngdqdw2#

这是一个很好的解决方案。

out = a[torch.arange(a.size(0)), b]

字符串

相关问题