我有两个2-dTensor,它们通过广播对齐,所以如果我加/减它们,我会产生一个巨大的3-dTensor。虽然我真的不需要这样,因为我将在一维上执行mean
。在这个演示中,我解压Tensor来显示它们如何对齐,但它们在其他方面是2-d的。
x = torch.tensor(...) # (batch , 1, B)
y = torch.tensor(...) # (1, , A, B)
out = torch.cos(x - y).mean(dim=2) # (batch, B)
可能的解决方案:
- 一个代数简化,但我的生活我还没有解决这个问题。
- 一些PyTorch原语会有帮助吗?这是余弦相似性,但是,有点不同于
torch.cosine_similarity
。我将它应用于复数的.angle()
。 - 高效循环的自定义C/CPython代码。
- 其他人?
1条答案
按热度按时间9wbgstp71#
为了保存内存,我建议使用
torch.einsum
:我们可以利用三角恒等式在这种情况下,我们可以应用
einsum
,其中通常的求和将是求平均值,而两个乘积之间的+
将是稍后的另一个运算,因此简而言之虽然测量内存消耗有点繁琐,但我还是采用了测量时间的方法。在这里,您可以看到您最初的方法和我对各种大小的输入的建议。(生成这些图的脚本附在下面)。