pytorch 边索引Tensor的几何稀疏邻接矩阵

zf2sa74q  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(350)

我的数据对象有data.adj_t参数,它给出了稀疏邻接矩阵,我如何从这个矩阵中得到大小为[2, num_edges]edge_indexTensor?

wnvonmuf

wnvonmuf1#

docs中所示:
由于此功能仍处于试验阶段,因此某些操作(* 例如 * 图形池方法)可能仍需要您输入edge_index。您可以通过以下方式将adj_t转换回(edge_index, edge_attr)

  1. row, col, edge_attr = adj_t.t().coo()
  2. edge_index = torch.stack([row, col], dim=0)
nzrxty8p

nzrxty8p2#

您可以使用torch_geometric.utils.convert.from_scipy_sparse_matrix

  1. >>> from torch_geometric.utils.convert import from_scipy_sparse_matrix
  2. >>> edge_index = torch.tensor([
  3. ... [0, 1, 1, 2, 2, 3],
  4. ... [1, 0, 2, 1, 3, 2],
  5. >>> ])
  6. >>> adj = to_scipy_sparse_matrix(edge_index)
  7. >>> # `edge_index` and `edge_weight` are both returned
  8. >>> from_scipy_sparse_matrix(adj)
  9. (tensor([[0, 1, 1, 2, 2, 3],
  10. [1, 0, 2, 1, 3, 2]]),
  11. tensor([1., 1., 1., 1., 1., 1.]))

相关问题