我有两个2DTensor,在不同的长度,都是相同的原始2DTensor的不同子集,我想找到所有匹配的“行”
例如
A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)
字符串
我只看到numpy解决方案,使用dtype作为dict,不适用于pytorch。
下面是我在numpy中的操作
arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)
型
3条答案
按热度按时间p5fdfcr11#
这个答案是在OP用其他限制更新问题之前发布的,这些限制大大改变了问题。
TL;DR你可以这样做:
字符串
首先,假设你有:
型
我们可以检查
A == B
返回:型
所以我们想要的是在这些行中,它们都是
True
。为此,我们可以使用.all()
操作并指定感兴趣的维度,在我们的例子中为1
:型
你实际上想知道的是
True
s在哪里。为此,我们可以得到torch.where()
函数的第一个输出:型
cwdobuhd2#
如果A和B是2DTensor,下面的代码查找索引使得
A[indices] == B
。如果多个索引满足此条件,则返回找到的第一个索引。如果A中不存在B的所有元素,则忽略相应的索引。字符串
3wabscal3#
如果两个Tensor的行数不同,那么我们不能直接比较Tensor。我们必须首先在其中一个Tensor上添加一个虚拟维度。一步一步:
1.创建Tensor
字符串
1.添加虚拟尺寸并获得成对比较:
型
输出将是一个4x3x3Tensor,其中4个
i
子Tensor中的每一个都是A[i] == B
。1.获取指示哪些索引具有完美“行”匹配的Tensor:
型
输出是一个4x3的Tensor。包含
True
元素的行包含完美的行匹配。1.获取具有完美匹配的行:
型
1.最后,获取
A
中与B
中匹配的行的索引:型
要获得B中与A中匹配的行的索引,只需交换Tensor:
型
这个问题有一个类似的numpy版本here,我的答案非常受the answer there by Ehsan的启发。