pytorch torch.topk中sorted=False的用法是什么?

aamkag61  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(176)

我正在看文档和一些关于torch.topk的讨论。它说,如果sorted=False,则无法保证返回的订单。sorted=False的用法是什么?为什么它有这个参数?如果我们设置sorted=False,返回顺序的规则是什么?
我有下面的例子:

>>> x = torch.rand(8)
>>> x
tensor([0.8618, 0.0271, 0.5122, 0.2415, 0.4987, 0.4552, 0.2030, 0.7258])
>>> y = torch.topk(x, 3, sorted=False)
>>> y
torch.return_types.topk(
values=tensor([0.7258, 0.8618, 0.5122]),
indices=tensor([7, 0, 2]))
>>> y = torch.topk(x, 3)
>>> y
torch.return_types.topk(
values=tensor([0.8618, 0.7258, 0.5122]),
indices=tensor([0, 7, 2]))
nwlqm0z1

nwlqm0z11#

从PyTorch源代码中可以看出,当k足够小k * 64 <= n)时,C++实现将运行std::partial_sort,无论是否需要sorted,因此结果将始终排序。

否则,就像您的示例中的情况一样(k * 64 = 192 > n = 8),实现将首先运行“通常”使用Introselect算法的std::nth_element(基于堆的gcc实现?),如果是sorted=True,则在top-k位置运行额外的std::sort

最后,我谨指出,

  • sorted=False的用法是当你只对最前面的-k值感兴趣而不是它们的顺序时。
  • 使用这个参数,您可以保存一个std::sort(如果是k * 64 > n),这样效率会更高。
  • AFAIK,使用sorted=False返回的顺序应该是不可预测的,这取决于底层的std::nth_element算法。

相关问题