numpy 获取2D数组的行方向最大值的列索引(随机平局决胜)

piwo6bdm  于 2023-01-13  发布在  其他
关注(0)|答案(2)|浏览(120)

给定一个2D numpy数组,我想用每一行的最大值的列索引构造一个数组,到目前为止,arr.argmax(1)工作得很好,然而,对于我的特定情况,对于某些行,可能有2列或更多列包含最大值,在这种情况下,我想随机选择一个列索引(而不是像.argmax(1)那样选择第一个索引)。
例如,对于以下arr

  1. arr = np.array([
  2. [0, 1, 0],
  3. [1, 1, 0],
  4. [2, 1, 3],
  5. [3, 2, 2]
  6. ])

可能有两种结果:array([1, 0, 2, 0])array([1, 1, 2, 0]),每个都以1/2的概率选择。
我的代码使用列表解析返回预期的输出:

  1. idx = np.arange(arr.shape[1])
  2. ans = [np.random.choice(idx[ix]) for ix in arr == arr.max(1, keepdims=True)]

但是我在寻找一个优化的numpy解决方案,换句话说,我如何用numpy方法替换列表解析,使代码适用于更大的数组?

zsohkypk

zsohkypk1#

按如下方式使用scipy.stats.rankdataapply_along_axis

  1. import numpy as np
  2. from scipy.stats import rankdata
  3. ranks = rankdata(-arr, axis = 1, method = "min")
  4. func = lambda x: np.random.choice(np.where(x==1)[0])
  5. idx = np.apply_along_axis(func, 1, ranks)
  6. print(idx)

它返回[1 0 2 0]或[1 1 2 0]。
其主要思想是rankdata计算每一行中每一个值的秩,最大值将为1。func随机选择一个对应值为1的索引。最后,apply_along_axisfunc应用于arr的每一行。

mwg9r5ms

mwg9r5ms2#

在得到一些建议后,我下线了,当我们把标记行最大值的布尔数组乘以一个相同形状的随机数组时,最大值的随机化是可能的,然后剩下的就是一个简单的argmax(1)调用。

  1. # boolean array that flags maximum values of each row
  2. mxs = arr == arr.max(1, keepdims=True)
  3. # random array where non-maximum values are zero and maximum values are random values
  4. random_arr = np.random.rand(*arr.shape) * mxs
  5. # row-wise maximum of the auxiliary array
  6. ans = random_arr.argmax(1)

timeit测试显示,对于(507_563, 12)形状的数据,这段代码在我的机器上运行了大约172 ms,而问题中的循环运行了11秒,因此速度大约快了63倍。

相关问题