假设我有来自Pytorch或Keras预测的概率,结果是softmax函数
from scipy.special import softmax
probs = softmax(np.random.randn(20,10),1) # 20 instances and 10 class probabilities
probs
我想从这个numpy数组中找到前5个索引,我所要做的就是在结果上运行一个循环,类似于:
for index in top_5_indices:
if index in result:
print('Found')
如果我的结果在前5名,我会得到。Pytorch
有top-k
功能,我看过numpy.argpartition
,但我不知道如何做到这一点?
3条答案
按热度按时间4xrmg8kj1#
numpy中的argpartition(a,k)函数将输入数组a的索引重新排列在第k个最小元素周围,这样所有较小元素的索引都在左边结束,所有较大元素的索引都在右边结束。时间复杂度为O(n)。
所以你可以得到5个最大元素的指数,如下所示:
nuypyhwy2#
稍微贵一点,但是
argsort
可以:如果我们谈论的是pytorch:
n6lpvg4x3#
现有的答案是正确的,但我想对它们进行扩展,以提供一个自包含函数,其行为与纯
numpy
的torch.topk
完全相同。下面是这个函数(我已经在内联中包含了指令):
要验证正确性,您可以对torch进行测试:
np.take_along_axis
是用于访问更高维度中的原始值的recommended to be used withnp.argpartition
。np.argpartition
比np.argsort
快,因为它不对整个数组排序。This answer声称它使用O(n)
而不是'O(n log