pytorch 在2个2DTensor中 Torch 查找匹配行的索引

pkbketx9  于 11个月前  发布在  其他
关注(0)|答案(3)|浏览(164)

我有两个2DTensor,在不同的长度,都是相同的原始2DTensor的不同子集,我想找到所有匹配的“行”
例如

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)

字符串
我只看到numpy解决方案,使用dtype作为dict,不适用于pytorch。
下面是我在numpy中的操作

arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)

p5fdfcr1

p5fdfcr11#

这个答案是在OP用其他限制更新问题之前发布的,这些限制大大改变了问题。

TL;DR你可以这样做:

torch.where((A == B).all(dim=1))[0]

字符串
首先,假设你有:

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])


我们可以检查A == B返回:

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])


所以我们想要的是在这些行中,它们都是True。为此,我们可以使用.all()操作并指定感兴趣的维度,在我们的例子中为1

>>> (A == B).all(dim=1)
tensor([ True, False,  True])


你实际上想知道的是True s在哪里。为此,我们可以得到torch.where()函数的第一个输出:

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])

cwdobuhd

cwdobuhd2#

如果A和B是2DTensor,下面的代码查找索引使得A[indices] == B。如果多个索引满足此条件,则返回找到的第一个索引。如果A中不存在B的所有元素,则忽略相应的索引。

values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])

字符串

3wabscal

3wabscal3#

如果两个Tensor的行数不同,那么我们不能直接比较Tensor。我们必须首先在其中一个Tensor上添加一个虚拟维度。一步一步:
1.创建Tensor

A = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[3,3,3]])
B = torch.tensor([[1,2,3],[7,8,9],[4,4,4]])

字符串
1.添加虚拟尺寸并获得成对比较:

B == A.unsqueeze(1)


输出将是一个4x3x3Tensor,其中4个i子Tensor中的每一个都是A[i] == B
1.获取指示哪些索引具有完美“行”匹配的Tensor:

(B == A.unsqueeze(1)).all(-1)


输出是一个4x3的Tensor。包含True元素的行包含完美的行匹配。
1.获取具有完美匹配的行:

(B == A.unsqueeze(1)).all(-1).any(-1)


1.最后,获取A中与B中匹配的行的索引:

torch.where((B == A.unsqueeze(1)).all(-1).any(-1))[0]
>> tensor([0, 2])


要获得B中与A中匹配的行的索引,只需交换Tensor:

torch.where((A == B.unsqueeze(1)).all(-1).any(-1))[0]
>> tensor([0, 1])


这个问题有一个类似的numpy版本here,我的答案非常受the answer there by Ehsan的启发。

相关问题