带向量的PyTorch切片矩阵

h79rfbju  于 2023-02-04  发布在  其他
关注(0)|答案(2)|浏览(138)

假设有一个矩阵和一个向量,如下所示:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

有没有一种方法可以将其分割为x[y],结果是:

res = [1, 6, 8]

所以基本上我取y的第一个元素和x中对应于第一行和元素列的元素。

vjhs03f7

vjhs03f71#

可以将相应的行索引指定为:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
tensor([1, 6, 8])
6rqinv9w

6rqinv9w2#

pytorch中的高级索引就像NumPy's一样工作,即索引数组在轴上一起广播,所以你可以像FBruzzesi的答案那样做。
虽然与np.take_along_axis类似,但在pytorch中,您也可以使用torch.gather来获取沿着特定轴的值:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])

相关问题