返回非平面索引的numpy数组的Argmax,但在条件数组上

vq8itlhq  于 2023-06-06  发布在  其他
关注(0)|答案(2)|浏览(137)

我想特别关注以下question

如何在合适的a索引中获取a[...]的argmax

>>> a = (np.random.random((10, 10))*10).astype(int)
>>> a
array([[4, 1, 7, 4, 3, 3, 8, 9, 3, 0],
       [7, 7, 8, 9, 9, 6, 1, 4, 2, 0],
       [6, 9, 4, 9, 2, 7, 9, 0, 8, 6],
       [2, 4, 7, 8, 0, 6, 0, 7, 1, 8],
       [7, 9, 7, 0, 1, 2, 3, 7, 9, 6],
       [7, 1, 1, 0, 5, 1, 8, 8, 5, 5],
       [5, 4, 3, 0, 0, 4, 4, 5, 5, 4],
       [9, 5, 0, 5, 8, 1, 6, 4, 8, 5],
       [5, 8, 0, 8, 2, 6, 4, 9, 5, 1],
       [2, 5, 0, 1, 4, 0, 0, 9, 6, 4]])
>>> np.unravel_index(a.argmax(), a.shape)
(0, 7)
>>> np.unravel_index(a[a>5].argmax(), a.shape)
(0, 2)
>>> np.unravel_index(a[a>5].argmax(), a[a>5].shape)
(2,)
mf98qq94

mf98qq941#

你可以考虑使用masked API:

import numpy as np

arr = np.random.randint(10, 100, size=(10, 10))
mask = arr > 50

# Note: values True in `mask` are considered "invalid"
# or "masked", and thus disregarded. This is opposite
# the behavior in boolean mask indexing, where only
# the True values are retrieved.

masked = np.ma.array(arr, mask=mask)
out = np.unravel_index(masked.argmax(), masked.shape)

结果:

>>> arr
array([[58, 75, 78, 46, 89, 54, 35, 18, 13, 99],
       [30, 11, 24, 10, 15, 41, 40, 15, 94, 28],
       [84, 84, 83, 72, 39, 22, 57, 51, 91, 23],
       [54, 99, 72, 63, 30, 14, 91, 46, 98, 74],
       [27, 90, 93, 25, 41, 82, 39, 42, 57, 64],
       [98, 63, 79, 13, 91, 12, 36, 71, 95, 30],
       [23, 34, 51, 19, 37, 31, 58, 65, 20, 31],
       [26, 73, 67, 21, 67, 89, 72, 80, 11, 48],
       [87, 64, 38, 74, 60, 31, 30, 54, 71, 44],
       [78, 94, 62, 38, 79, 23, 61, 62, 18, 25]])
>>> print(masked)
[[-- -- -- 46 -- -- 35 18 13 --]
 [30 11 24 10 15 41 40 15 -- 28]
 [-- -- -- -- 39 22 -- -- -- 23]
 [-- -- -- -- 30 14 -- 46 -- --]
 [27 -- -- 25 41 -- 39 42 -- --]
 [-- -- -- 13 -- 12 36 -- -- 30]
 [23 34 -- 19 37 31 -- -- 20 31]
 [26 -- -- 21 -- -- -- -- 11 48]
 [-- -- 38 -- -- 31 30 -- -- 44]
 [-- -- -- 38 -- 23 -- -- 18 25]]
>>> out
(7, 9)
>>> arr[out]
48
xeufq47z

xeufq47z2#

对于一个面具,什么是:

np.where( (a > 5) & (a == a[a>5].max()))

mask = a > 5
np.where( mask & (a == a[mask].max()))

相关问题