获取2D numpy数组中大于阈值的元素的索引

lmyy7pcs  于 2023-03-12  发布在  其他
关注(0)|答案(2)|浏览(235)

我有一个二维numpy数组:

x = np.array([
 [  1.92043482e-04,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   2.41005634e-03,   0.00000000e+00,
    7.19330120e-04,   0.00000000e+00,   0.00000000e+00,   1.42886875e-04,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   9.79279411e-05,   7.88888657e-04,   0.00000000e+00,
    0.00000000e+00,   1.40425916e-01,   0.00000000e+00,   1.13955893e-02,
    7.36868947e-03,   3.67091988e-04,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   1.72037105e-03,   1.72377961e-03,
    0.00000000e+00,   0.00000000e+00,   1.19532061e-01,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   3.37249481e-04,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   1.75111492e-03,   0.00000000e+00,
    0.00000000e+00,   1.12639313e-02],
 [  0.00000000e+00,   0.00000000e+00,   1.10271735e-04,   5.98736562e-04,
    6.77961628e-04,   7.49569659e-04,   0.00000000e+00,   0.00000000e+00,
    2.91697850e-03,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   3.30257021e-04,   2.46629275e-04,
    0.00000000e+00,   1.87586441e-02,   6.49103144e-04,   0.00000000e+00,
    1.19046355e-04,   0.00000000e+00,   0.00000000e+00,   2.69499898e-03,
    1.48525386e-02,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   1.18803119e-03,
    3.93100829e-04,   0.00000000e+00,   3.76245304e-04,   2.79537738e-02,
    0.00000000e+00,   1.20738457e-03,   9.74669064e-06,   7.18680093e-04,
    1.61546793e-02,   3.49360861e-04,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00]])

如何获得大于0.01的元素的索引?
现在,我正在执行t = np.argmax(x, axis=1)来获取每个变量的最大值的索引,结果是:[21 35] .如何实现上述目标?

b09cbbtk

b09cbbtk1#

可以使用np.argwhere返回数组中所有与布尔条件匹配的项的索引:

>>> x = np.array([[0,0.2,0.5],[0.05,0.01,0]])

>>> np.argwhere(x > 0.01)
array([[0, 1],
       [0, 2],
       [1, 0]])
yqlxgs2m

yqlxgs2m2#

也可以使用np.nonzero()来获得数组元组,x的每个维度对应一个数组元组,其中x包含条件为True的索引。

x_indices, y_indices = np.nonzero(x > 0.01)
# (array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int64), array([21, 23, 34, 49, 17, 24, 35, 40], dtype=int64))

它的一个优点是可以立即用来索引数组,例如,如果我们想过滤大于0.01的元素,那么

x[np.nonzero(x>0.01)]

nonzero按维度对索引进行分组,而argwhere按元素进行分组(这只是从另一个Angular 看同一件事),因此以下为True:

(np.argwhere(x>0.01).T == np.nonzero(x>0.01)).all()   # True

相关问题