我想通过给出源Tensor和索引Tensor来选择2dTensor的特定行。输入:a = torch.FloatTensor([[1,1,1],[2,2,2]],[[9,9,9],[5,5,5]])B = torch.IntTensor([1,0])有解决办法吗?预期结果:2,2,2,[9,9,9]]
ltqd579y1#
我建议使用聚集方法和2D索引
a = torch.FloatTensor([[[1,1,1],[2,2,2]],[[9,9,9],[5,5,5]]]) b = torch.LongTensor([1,0]) R = a.shape[0] C = a.shape[2] idx = b.unsqueeze(dim=1).repeat(1, C).view(R, 1, C) torch.gather(a, 1, idx)
字符串
pjngdqdw2#
这是一个很好的解决方案。
out = a[torch.arange(a.size(0)), b]
2条答案
按热度按时间ltqd579y1#
我建议使用聚集方法和2D索引
字符串
pjngdqdw2#
这是一个很好的解决方案。
字符串