PyTorch矩阵乘法不考虑切片

o7jaxewo  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(112)

我不得不批处理长时间的输入,注意到批处理和非批处理结果之间的差异。最后,我发现了第一个差异,结果如下:

import torch

n = 20

vec = torch.rand(n, 20)
a = torch.rand(30, 20)

for i in range(1, n+1):
    print(i, torch.equal(
        torch.nn.functional.linear(vec, a)[:i],
        torch.nn.functional.linear(vec[:i], a)))

产生输出:
这只是一个操作,当多次组合时(如在Transformer中),它可能会导致较大的发散,扩大torch.allclose输出True的atol。为什么会这样,我们能做点什么吗?

ecr0jaav

ecr0jaav1#

欢迎来到浮点运算的美丽世界!float运算引入舍入误差,并且矩阵乘法将它们累加到有效值。https://pytorch.org/docs/stable/notes/numerical_accuracy.html如果避免不精确的舍入,

vec = torch.floor (torch.rand(n, 20)*10)
a = torch.floor( torch.rand(30, 20)*10 )

你会得到所有的True-s。
可能的解决方案是使用torch.isclose
DEL

相关问题