scipy 将numpy中argpartition返回的值以外的所有值设置为0

yqkkidmi  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(139)

我想基于n*n距离矩阵为创建一个稀疏矩阵,保留每行距离矩阵中的k个最小值。我已经通过使用np.argpartition得到了正确的索引,但当我试图从这个矩阵创建一个掩码时,它只会选择对角线为True,其他所有内容为False。

nn_indices = np.argpartition(data, k - 1)[:, :k]
mask = np.isin(data, nn_indices)

你知道我如何使用argpartition的输出为这些索引创建布尔掩码吗?
例如-对于n = 4,k = 2

[[4, 6, 1, 3]
 [1, 5, 6, 7]
 [4, 7, 2, 3]
 [7, 1, 8, 2]]

参数分区输出:

[[2, 3]
 [0, 1]
 [2, 3]
 [1, 3]]

所需输出:

[[0, 0, 1, 3]
 [1, 5, 0, 0]
 [0, 0, 2, 3]
 [0, 1, 0, 2]]

我看了一下scipy.csr_matrix,但不知道如何对列和行数据进行排序。
任何帮助都将不胜感激!

q43xntqr

q43xntqr1#

分配一个全零数组,然后简单地填充它:

>>> mask = np.zeros(data.shape, bool)
>>> mask[np.arange(len(data))[:, None], nn_indices] = True
>>> mask
array([[False, False,  True,  True],
       [ True,  True, False, False],
       [False, False,  True,  True],
       [False,  True, False,  True]])
>>> np.where(mask, data, 0)
array([[0, 0, 1, 3],
       [1, 5, 0, 0],
       [0, 0, 2, 3],
       [0, 1, 0, 2]])

相关问题