我想做一个类似于矩阵乘法的运算,只是我想检查是否相等,而不是乘法。我想达到的效果类似于以下内容:
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.uint8)
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).to(torch.uint8)
result = [[sum(a[i] == b [j]) for j in range(len(b))] for i in range(len(a))]
有没有一种方法可以让我使用einsum,或者pytorch中的其他函数来有效地实现上述功能?
3条答案
按热度按时间uurv41yg1#
您可以使用torch.repeat和torch.repeat_interleave:
sxissh062#
您可以使用broadcasting执行相同的操作,例如,使用
这里
None
只是引入了一个虚拟尺寸-或者你可以使用不太直观的.unsqueeze()
来代替。dzjeubhm3#
矩阵乘法是
ij,jk->ik
在einsum符号中,所有这些运算在不同的详细程度下是等效的:“将
i
与k
维度相乘,并减去j
维度”现在从该函数分解应当清楚,乘法可以用任何二进制运算来代替,例如相等运算。
不幸的是,pytorch中没有通用形式的einsum(AFAIK)来交换“开箱即用”的乘法运算。不过,有
einops
库,它基本上是PyTorch等深度学习框架的 Package 器。