pytorch Torch 矩阵等式和运算

r9f1avp5  于 2022-11-29  发布在  其他
关注(0)|答案(3)|浏览(207)

我想做一个类似于矩阵乘法的运算,只是我想检查是否相等,而不是乘法。我想达到的效果类似于以下内容:

a = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.uint8)
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).to(torch.uint8)
result = [[sum(a[i] == b [j]) for j in range(len(b))] for i in range(len(a))]

有没有一种方法可以让我使用einsum,或者pytorch中的其他函数来有效地实现上述功能?

uurv41yg

uurv41yg1#

您可以使用torch.repeat和torch.repeat_interleave:

a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

mask = a.repeat_interleave(3, dim=0) == b.repeat((2, 1))

torch.sum(mask, axis=1).reshape(a.shape)

# output
tensor([[3, 0, 0],
        [0, 3, 0]])
sxissh06

sxissh062#

您可以使用broadcasting执行相同的操作,例如,使用

result = (a[:, None, :] == b[None, :, :]).sum(dim=2)

这里None只是引入了一个虚拟尺寸-或者你可以使用不太直观的.unsqueeze()来代替。

dzjeubhm

dzjeubhm3#

矩阵乘法是ij,jk->ik在einsum符号中,所有这些运算在不同的详细程度下是等效的:

a @ b
torch.einsum("ij,jk", a, b)
torch.einsum("ij,jk->ik", a, b)
(a[:,:,None] * b[None,:,:]).sum(1)

“将ik维度相乘,并减去j维度”

i, j, k             i,    j, k
a: (2, 3)        =>    (2,    3, None)
b:    (3, 3)           (None, 3, 3)

现在从该函数分解应当清楚,乘法可以用任何二进制运算来代替,例如相等运算。
不幸的是,pytorch中没有通用形式的einsum(AFAIK)来交换“开箱即用”的乘法运算。不过,有einops库,它基本上是PyTorch等深度学习框架的 Package 器。

相关问题