pytorch 当我使用past_key_values时,gpt2 logits不同

hwamh0ep  于 2024-01-09  发布在  Git
关注(0)|答案(1)|浏览(206)

我尝试使用past_key_values来加速推理:

import torch
from transformers import GPT2LMHeadModel

torch.set_default_device("cuda")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
model.to("cuda")

seq = torch.tensor([1, 2, 3, 4, 5])
original_out = model(input_ids=seq).logits

seq2 = torch.tensor([1, 2, 3])
key_values = model(input_ids=seq2, use_cache=True).past_key_values
new_seq = torch.tensor([4, 5])
magic = model(input_ids=new_seq, past_key_values=key_values).logits

print(torch.equal(original_out[-1, :], magic[-1, :]))

字符串
但这将返回False,而我希望它返回True。

tyg4sfes

tyg4sfes1#

你的代码很好,但你遇到了一些浮点精度问题。torch.equal检查两个Tensor是否具有相同的形状和完全相同的值,但你的两个变量略有不同:

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

seq = torch.tensor([1, 2, 3, 4, 5])
seq2 = torch.tensor([1, 2, 3])
new_seq = torch.tensor([4, 5])

with torch.inference_mode():
  original_out = model(input_ids=seq).logits[-1, :]
  key_values = model(input_ids=seq2, use_cache=True).past_key_values
  magic = model(input_ids=new_seq, past_key_values=key_values).logits[-1, :]

print(torch.equal(original_out, magic))
# Checking the difference of the first 20 elements
print((original_out[:20] - magic[:20]))

字符串
输出量:

False
tensor([ 7.6294e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -7.6294e-06,  1.5259e-05,  1.5259e-05,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.2888e-05,
        -7.6294e-06, -7.6294e-06,  1.5259e-05,  0.0000e+00,  7.6294e-06])


我推荐使用torch.allclose来比较两个Tensor,因为它考虑了一些公差:

print(torch.allclose(original_out, magic))


输出量:

True

相关问题