在PyTorch中,当添加两个Tensor然后立即崩溃时,如何避免昂贵的广播?

bnl4lu3b  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(117)

我有两个2-dTensor,它们通过广播对齐,所以如果我加/减它们,我会产生一个巨大的3-dTensor。虽然我真的不需要这样,因为我将在一维上执行mean。在这个演示中,我解压Tensor来显示它们如何对齐,但它们在其他方面是2-d的。

x = torch.tensor(...)              # (batch , 1,  B)
y = torch.tensor(...)              # (1,    , A,  B)
out = torch.cos(x - y).mean(dim=2) # (batch, B)

可能的解决方案:

  • 一个代数简化,但我的生活我还没有解决这个问题。
  • 一些PyTorch原语会有帮助吗?这是余弦相似性,但是,有点不同于torch.cosine_similarity。我将它应用于复数的.angle()
  • 高效循环的自定义C/CPython代码。
  • 其他人?
9wbgstp7

9wbgstp71#

为了保存内存,我建议使用torch.einsum:我们可以利用三角恒等式

cos(x-y) = cos(x)*cos(y) + sin(x)*sin(y)

在这种情况下,我们可以应用einsum,其中通常的求和将是求平均值,而两个乘积之间的+将是稍后的另一个运算,因此简而言之

xs, ys = torch.sin(x), torch.sin(y)
xc, yc = torch.cos(x), torch.cos(y)
# use einsum for sin/cos products and averaging sum, use + for sum of products: 
out = (torch.einsum('i k, j k -> i k', xs, ys) + torch.einsum('i k, j k -> i k', xc, yc)) / y.shape[1]

虽然测量内存消耗有点繁琐,但我还是采用了测量时间的方法。在这里,您可以看到您最初的方法和我对各种大小的输入的建议。(生成这些图的脚本附在下面)。

import matplotlib.pyplot as plt
import torch
import time

def main():
    ns = torch.logspace(1, 3.2, 20).to(torch.long)
    tns = []; tes = []
    for n in ns:
        tn, te = compare(n)
        tns.append(tn); tes.append(te)
    plt.loglog(ns, tns, ':.'); plt.loglog(ns, tes, '.-'); plt.loglog(ns, 1e-6*ns**1, ':'); plt.loglog(ns, 1e-6*ns**2, ':'); plt.legend(['naive', 'einsum', 'x^1', 'x^2']);
    plt.show()

def compare(n):
    batch = a = b = n
    x = torch.zeros((batch, b)) # (batch , 1,  B)
    y = torch.zeros((a, b))  # (1,    , A,  B)
    t = time.perf_counter(); ra = af(x.unsqueeze(1), y.unsqueeze(0)); print('naive method', tn := time.perf_counter() - t)
    t = time.perf_counter(); rb = bf(x, y); print('einsum method', te := time.perf_counter() - t)
    print((ra-rb).abs().max()) # verify we have same results
    return tn, te

def af(x, y):
    return torch.cos(x - y).mean(dim=2) 

def bf(x, y):
    xs, ys = torch.sin(x), torch.sin(y)
    xc, yc = torch.cos(x), torch.cos(y)
    return (torch.einsum('i k, j k -> i k', xs, ys) + torch.einsum('i k, j k -> i k', xc, yc)) / y.shape[1]

main()

相关问题