对于SciPy稀疏矩阵,如何获取低于阈值的值的索引

1rhkuytd  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(134)

当使用条件语句过滤SciPy稀疏数组中的值时,如何获得这些值的索引?
我尝试使用apply条件语句到csc_array().data来获取索引,但是它们与csc_array().nonzero()索引不匹配。下面是我所面临的问题的一个例子:

import numpy as np
from scipy.sparse import dok_array, csc_array

m = dok_array((1000, 1000))
for i, j in zip(np.random.randint(0, 1000, 100), np.random.randint(0, 1000, 100)):
    m[i, j] = np.random.random()

threshold = 0.3
tmp = csc_array(m)
mask = tmp.data < threshold
i, j = tmp.nonzero()
i_mask, j_mask = i[mask], j[mask]
assert np.alltrue(tmp[i_mask, j_mask] < threshold), "This fails!!!"
5w9g7ksd

5w9g7ksd1#

要解决csc_array().datacsc_array().nonzero()的顺序不匹配的问题,只需在整个过程中使用nonzero索引即可,如下所示:

import numpy as np
from scipy.sparse import dok_array, csc_array

m = dok_array((1000, 1000))
for i, j in zip(np.random.randint(0, 1000, 100), np.random.randint(0, 1000, 100)):
    m[i, j] = np.random.random()

threshold = 0.3
tmp = csc_array(m)
i, j = tmp.nonzero()
mask = tmp[i, j] < threshold
i_mask, j_mask = i[mask], j[mask]
tmp[i_mask, j_mask] = 0
tmp.eliminate_zeros()
assert np.alltrue(threshold < tmp.data), "Should not see this!!!" 

m = dok_array(tmp)

相关问题