给定一个2D numpy数组,我想用每一行的最大值的列索引构造一个数组,到目前为止,arr.argmax(1)
工作得很好,然而,对于我的特定情况,对于某些行,可能有2列或更多列包含最大值,在这种情况下,我想随机选择一个列索引(而不是像.argmax(1)
那样选择第一个索引)。
例如,对于以下arr
:
arr = np.array([
[0, 1, 0],
[1, 1, 0],
[2, 1, 3],
[3, 2, 2]
])
可能有两种结果:array([1, 0, 2, 0])
和array([1, 1, 2, 0])
,每个都以1/2的概率选择。
我的代码使用列表解析返回预期的输出:
idx = np.arange(arr.shape[1])
ans = [np.random.choice(idx[ix]) for ix in arr == arr.max(1, keepdims=True)]
但是我在寻找一个优化的numpy解决方案,换句话说,我如何用numpy方法替换列表解析,使代码适用于更大的数组?
2条答案
按热度按时间zsohkypk1#
按如下方式使用
scipy.stats.rankdata
和apply_along_axis
。它返回[1 0 2 0]或[1 1 2 0]。
其主要思想是
rankdata
计算每一行中每一个值的秩,最大值将为1。func
随机选择一个对应值为1的索引。最后,apply_along_axis
将func
应用于arr
的每一行。mwg9r5ms2#
在得到一些建议后,我下线了,当我们把标记行最大值的布尔数组乘以一个相同形状的随机数组时,最大值的随机化是可能的,然后剩下的就是一个简单的
argmax(1)
调用。timeit测试显示,对于
(507_563, 12)
形状的数据,这段代码在我的机器上运行了大约172 ms,而问题中的循环运行了11秒,因此速度大约快了63倍。