我不得不批处理长时间的输入,注意到批处理和非批处理结果之间的差异。最后,我发现了第一个差异,结果如下:
import torch
n = 20
vec = torch.rand(n, 20)
a = torch.rand(30, 20)
for i in range(1, n+1):
print(i, torch.equal(
torch.nn.functional.linear(vec, a)[:i],
torch.nn.functional.linear(vec[:i], a)))
产生输出:
这只是一个操作,当多次组合时(如在Transformer中),它可能会导致较大的发散,扩大torch.allclose输出True的atol。为什么会这样,我们能做点什么吗?
1条答案
按热度按时间ecr0jaav1#
欢迎来到浮点运算的美丽世界!
float
运算引入舍入误差,并且矩阵乘法将它们累加到有效值。https://pytorch.org/docs/stable/notes/numerical_accuracy.html如果避免不精确的舍入,你会得到所有的
True
-s。可能的解决方案是使用
torch.isclose
。DEL