python 检索CSR矩阵的值

5us2dqdw  于 2022-12-02  发布在  Python
关注(0)|答案(2)|浏览(185)

我有一个CSR矩阵,我希望能够检索列索引和值。
下面是我创建矩阵的方法(使用scipy.sparse中的csr_matrix):

  1. indptr = np.empty(nbr_of_rows + 1) # nbr_of_rows = 134,465
  2. indptr[0] = 0
  3. for i in range(1, len(indptr)):
  4. indptr[i] = indptr[i-1] + len(data[i-1]) # type(data) = list ; len(data) = 134,465 ; type(data[0]) = numpy.darray (each subarray has a different length)
  5. data = np.concatenate(data).ravel() # now I have type(data) = numpy.darray ; len(data) = 2,821,574
  6. ind = np.concatenante(ind).ravel # same than above
  7. X = csr_matrix((data, ind, indptr), shape=(nbr_of_rows, nbr_of_columns)) # nbr_of_columns = 3,991
  8. print(f"The matrix has a shape of {X.shape} and a sparsity of {(1 - (X.nnz / (X.shape[0] * X.shape[1]))): .2%}.")
  9. # OUT: The matrix has a shape of (134465, 3991) and a sparsity of 99.47%.

到目前为止一切顺利(至少我是这么认为的)。但是现在,即使我设法检索了列索引,我也不能成功地检索值:

  1. np.alltrue(ind == X.nonzero()[1]) # True
  2. np.alltrue(data == X[X.nonzero()]) # False

当我看得更深时,我发现我得到了 * 几乎 * 所有的值(只有少量的错误):

  1. len(data) == len(X[X.nonzero()].tolist()[0]) # True
  2. len(np.argwhere((data==X[X.nonzero()]) == False)) # 2184

因此,在总共2,821,574个值中,我“只”得到了2,184个错误值。
有人能帮我从我的CSR矩阵中获得所有正确的值吗?

zf9nrax1

zf9nrax11#

根据您存储在矩阵中的值的类型(numpy.float64numpy.int64),下面的帖子可能会回答您的问题:https://github.com/scipy/scipy/issues/13329#issuecomment-753541268
特别是,注解“* 显然,当数据是numpy数组而不是列表时,我没有得到错误。*”表明将data作为numpy.array而不是list可以解决您的问题。
希望这至少能让你走上正轨。

esbemjvw

esbemjvw2#

如果没有data,我就无法复制您的问题,而且即使使用这么大的数组,我也可能不想这样做。
但是我将试着说明当用这种方法构造一个矩阵时,我期望发生什么。从另一个问题开始,我在Ipython会话中有一个小矩阵:

  1. In [60]: Mx
  2. Out[60]:
  3. <1x3 sparse matrix of type '<class 'numpy.intc'>'
  4. with 2 stored elements in Compressed Sparse Row format>
  5. In [61]: Mx.A
  6. Out[61]: array([[0, 1, 2]], dtype=int32)

nonzero返回coo索引,行,列

  1. In [62]: Mx.nonzero()
  2. Out[62]: (array([0, 0], dtype=int32), array([1, 2], dtype=int32))

csr属性包括:

  1. In [63]: Mx.data,Mx.indices,Mx.indptr
  2. Out[63]:
  3. (array([1, 2], dtype=int32),
  4. array([1, 2], dtype=int32),
  5. array([0, 2], dtype=int32))

现在,让我们使用Mx的属性创建一个新的矩阵。假设您正确地构建了indptrindicesdata,这应该可以模仿您所做的操作:

  1. In [64]: newM = sparse.csr_matrix((Mx.data, Mx.indices, Mx.indptr))
  2. In [65]: newM.A
  3. Out[65]: array([[0, 1, 2]], dtype=int32)

data两个矩阵之间的匹配:

  1. In [68]: Mx.data==newM.data
  2. Out[68]: array([ True, True])

dataid不匹配,但它们的基匹配。请参阅我最近的回答,了解为什么这是相关的
https://stackoverflow.com/a/74543855/901925

  1. In [75]: id(Mx.data.base), id(newM.data.base)
  2. Out[75]: (2255407394864, 2255407394864)

这意味着对newA的更改将出现在Mx中:

  1. In [77]: newM[0,1] = 100
  2. In [78]: newM.A
  3. Out[78]: array([[ 0, 100, 2]], dtype=int32)
  4. In [79]: Mx.A
  5. Out[79]: array([[ 0, 100, 2]], dtype=int32)

富勒试验

让我们对您的代码进行一个小规模测试:

  1. In [92]: data = np.array([[1.23,2],[3],[]],object); ind = np.array([[1,2],[3],[]],object)
  2. ...: indptr = np.empty(4)
  3. ...: indptr[0] = 0
  4. ...: for i in range(1, 4):
  5. ...: indptr[i] = indptr[i-1] + len(data[i-1])
  6. ...: data = np.concatenate(data).ravel()
  7. ...: ind = np.concatenate(ind).ravel() # same than above
  8. In [93]: data,ind,indptr
  9. Out[93]: (array([1.23, 2. , 3. ]), array([1., 2., 3.]), array([0., 2., 3., 3.]))

而稀疏矩阵:

  1. In [94]: X = sparse.csr_matrix((data, ind, indptr), shape=(3,3))
  2. In [95]: X
  3. Out[95]:
  4. <3x3 sparse matrix of type '<class 'numpy.float64'>'
  5. with 3 stored elements in Compressed Sparse Row format>

data匹配项:

  1. In [96]: X.data
  2. Out[96]: array([1.23, 2. , 3. ])
  3. In [97]: data == X.data
  4. Out[97]: array([ True, True, True])

且实际上是view

  1. In [98]: data[1]+=.23; data
  2. Out[98]: array([1.23, 2.23, 3. ])
  3. In [99]: X.A
  4. Out[99]:
  5. array([[0. , 1.23, 2.23],
  6. [0. , 0. , 0. ],
  7. [3. , 0. , 0. ]])

哎呀

我在指定X形状时出错:

  1. In [110]: X = sparse.csr_matrix((data, ind, indptr), shape=(3,4))
  2. In [111]: X.A
  3. Out[111]:
  4. array([[0. , 1.23, 2.23, 0. ],
  5. [0. , 0. , 0. , 3. ],
  6. [0. , 0. , 0. , 0. ]])
  7. In [112]: X.data
  8. Out[112]: array([1.23, 2.23, 3. ])
  9. In [113]: X.nonzero()
  10. Out[113]: (array([0, 0, 1], dtype=int32), array([1, 2, 3], dtype=int32))
  11. In [114]: X[X.nonzero()]
  12. Out[114]: matrix([[1.23, 2.23, 3. ]])
  13. In [115]: data
  14. Out[115]: array([1.23, 2.23, 3. ])
  15. In [116]: data == X[X.nonzero()]
  16. Out[116]: matrix([[ True, True, True]])
展开查看全部

相关问题