pytorch 如何从numpy数组中获取前K个值的索引

f1tvaqid  于 2023-02-08  发布在  其他
关注(0)|答案(3)|浏览(258)

假设我有来自Pytorch或Keras预测的概率,结果是softmax函数

from scipy.special import softmax
probs = softmax(np.random.randn(20,10),1) # 20 instances and 10 class probabilities
probs

我想从这个numpy数组中找到前5个索引,我所要做的就是在结果上运行一个循环,类似于:

for index in top_5_indices:
    if index in result:
        print('Found')

如果我的结果在前5名,我会得到。
Pytorchtop-k功能,我看过numpy.argpartition,但我不知道如何做到这一点?

4xrmg8kj

4xrmg8kj1#

numpy中的argpartition(a,k)函数将输入数组a的索引重新排列在第k个最小元素周围,这样所有较小元素的索引都在左边结束,所有较大元素的索引都在右边结束。时间复杂度为O(n)。
所以你可以得到5个最大元素的指数,如下所示:

np.argpartition(probs,-5)[-5:]
nuypyhwy

nuypyhwy2#

稍微贵一点,但是argsort可以:

idx = np.argsort(probs, axis=1)[:,-5:]

如果我们谈论的是pytorch:

probs = torch.from_numpy(softmax(np.random.randn(20,10),1))

values, idx = torch.topk(probs, k=5, axis=-1)
n6lpvg4x

n6lpvg4x3#

现有的答案是正确的,但我想对它们进行扩展,以提供一个自包含函数,其行为与纯numpytorch.topk完全相同。
下面是这个函数(我已经在内联中包含了指令):

def topk(array, k, axis=-1, sorted=True):
    # Use np.argpartition is faster than np.argsort, but do not return the values in order
    # We use array.take because you can specify the axis
    partitioned_ind = (
        np.argpartition(array, -k, axis=axis)
        .take(indices=range(-k, 0), axis=axis)
    )
    # We use the newly selected indices to find the score of the top-k values
    partitioned_scores = np.take_along_axis(array, partitioned_ind, axis=axis)
    
    if sorted:
        # Since our top-k indices are not correctly ordered, we can sort them with argsort
        # only if sorted=True (otherwise we keep it in an arbitrary order)
        sorted_trunc_ind = np.flip(
            np.argsort(partitioned_scores, axis=axis), axis=axis
        )
        
        # We again use np.take_along_axis as we have an array of indices that we use to
        # decide which values to select
        ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
        scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
    else:
        ind = partitioned_ind
        scores = partition_scores
    
    return scores, ind

要验证正确性,您可以对torch进行测试:

import torch
import numpy as np

x = np.random.randn(50, 50, 10, 10)

axis = 2  # Change this to any axis and it'll be fine

val_np, ind_np = topk(x, k=10, axis=axis)

val_pt, ind_pt = torch.topk(torch.tensor(x), k=10, dim=axis)

print("Values are same:", np.all(val_np == val_pt.numpy()))
print("Indices are same:", np.all(ind_np == ind_pt.numpy()))

相关问题