我正在看文档和一些关于torch.topk
的讨论。它说,如果sorted=False
,则无法保证返回的订单。sorted=False
的用法是什么?为什么它有这个参数?如果我们设置sorted=False
,返回顺序的规则是什么?
我有下面的例子:
>>> x = torch.rand(8)
>>> x
tensor([0.8618, 0.0271, 0.5122, 0.2415, 0.4987, 0.4552, 0.2030, 0.7258])
>>> y = torch.topk(x, 3, sorted=False)
>>> y
torch.return_types.topk(
values=tensor([0.7258, 0.8618, 0.5122]),
indices=tensor([7, 0, 2]))
>>> y = torch.topk(x, 3)
>>> y
torch.return_types.topk(
values=tensor([0.8618, 0.7258, 0.5122]),
indices=tensor([0, 7, 2]))
1条答案
按热度按时间nwlqm0z11#
从PyTorch源代码中可以看出,当
k
足够小(k * 64 <= n
)时,C++实现将运行std::partial_sort
,无论是否需要sorted
,因此结果将始终排序。否则,就像您的示例中的情况一样(
k * 64
= 192> n
= 8),实现将首先运行“通常”使用Introselect算法的std::nth_element
(基于堆的gcc实现?),如果是sorted=True
,则在top-k
位置运行额外的std::sort
。最后,我谨指出,
sorted=False
的用法是当你只对最前面的-k
值感兴趣而不是它们的顺序时。std::sort
(如果是k * 64 > n
),这样效率会更高。sorted=False
返回的顺序应该是不可预测的,这取决于底层的std::nth_element
算法。