如何在pytorch中删除所有出现在边中的节点

7gs2gvoe  于 2023-02-16  发布在  其他
关注(0)|答案(1)|浏览(165)

我有一个edge_index,希望删除其中的一个元素n = 3

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])



nodes = torch.unique(edges)
n = nodes[-1]  # I want to remove this from edge_index

我试过了,但是没用

arr = edges[~(edges == [n]).all(axis=1)]
carvr3hs

carvr3hs1#

将代码更改为

arr = edges[(~(edges == n)).all(axis=0).unsqueeze(0).repeat(edges.shape[0], 1)].reshape(edges.shape[0], -1)

相关问题