如何在索引处添加pytorchTensor?

bt1cpqcv  于 2022-11-29  发布在  其他
关注(0)|答案(2)|浏览(141)

我不得不承认,我对scatter* 和index* 操作有点困惑-我不确定它们中的任何一个是否能准确地完成我所寻找的任务,这非常简单:
给定某个二维Tensor

z = tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])

以及二维索引的列表(或Tensor?):

inds = tensor([[0, 0],
               [1, 1],
               [1, 2]])

我想在这些索引处向z添加一个标量(并且高效地完成):

znew = z.something_add(inds, 3)
->
znew = tensor([[4., 1., 1., 1.],
               [1., 4., 4., 1.],
               [1., 1., 1., 1.]])

如果有必要的话,我可以把这个标量变成任何形状的Tensor(其中所有元素= 3),但我宁愿不......

pkwftd7m

pkwftd7m1#

您必须提供两个清单来编制索引。第一个清单包含列位置,第二个清单包含栏位置。在您的范例中,它会是:

z[[0, 1, 1], [0, 1, 2]] += 3

Torch 。Tensor索引遵循Numpy。https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing有关详细信息,请参阅www.example.com。

plupiseo

plupiseo2#

这段程式码可以达成您的目的:

z_new = z.clone() # copy the tensor
z_new[inds[:, 0], inds[:, 1]] += 3 # modify selected indices of new tensor

在PyTorch中,可以使用另一个Tensor来索引Tensor的每个轴。

相关问题