假设我有一个Tensor,使用torch.topk函数,我得到Tensor的最大k个元素及其索引。如以下代码
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
现在假设我想把上面的最大k个元素设置为0,但是保持它们在原始Tensor中的相同位置。我该怎么做呢?
如果k =3,则新Tensor应该看起来像:
Tensor([ 1.,2.,0.,0.,0.])
基本上,我如何使用这些max k元素的索引(返回topk()torch函数)将这些位置的原始值清零(将它们的值设置为0)?
最好是一个 Torch 方法的建议,做什么我问。如果不是,这将是最好的解决方案是尽可能有效的。
提前感谢你的帮助:)
1条答案
按热度按时间iyr7buue1#
torch.topk
函数会传回所提供维度上前k个元素的索引。您可以使用torch.scatter
执行重新指派作业: