我有一个edge_index,希望删除其中的一个元素n = 3
n = 3
edges = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) nodes = torch.unique(edges) n = nodes[-1] # I want to remove this from edge_index
我试过了,但是没用
arr = edges[~(edges == [n]).all(axis=1)]
carvr3hs1#
将代码更改为
arr = edges[(~(edges == n)).all(axis=0).unsqueeze(0).repeat(edges.shape[0], 1)].reshape(edges.shape[0], -1)
1条答案
按热度按时间carvr3hs1#
将代码更改为