我不得不承认,我对scatter* 和index* 操作有点困惑-我不确定它们中的任何一个是否能准确地完成我所寻找的任务,这非常简单:
给定某个二维Tensor
z = tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
以及二维索引的列表(或Tensor?):
inds = tensor([[0, 0],
[1, 1],
[1, 2]])
我想在这些索引处向z添加一个标量(并且高效地完成):
znew = z.something_add(inds, 3)
->
znew = tensor([[4., 1., 1., 1.],
[1., 4., 4., 1.],
[1., 1., 1., 1.]])
如果有必要的话,我可以把这个标量变成任何形状的Tensor(其中所有元素= 3),但我宁愿不......
2条答案
按热度按时间pkwftd7m1#
您必须提供两个清单来编制索引。第一个清单包含列位置,第二个清单包含栏位置。在您的范例中,它会是:
Torch 。Tensor索引遵循Numpy。https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing有关详细信息,请参阅www.example.com。
plupiseo2#
这段程式码可以达成您的目的:
在PyTorch中,可以使用另一个Tensor来索引Tensor的每个轴。