我尝试使用past_key_values来加速推理:
import torch
from transformers import GPT2LMHeadModel
torch.set_default_device("cuda")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
model.to("cuda")
seq = torch.tensor([1, 2, 3, 4, 5])
original_out = model(input_ids=seq).logits
seq2 = torch.tensor([1, 2, 3])
key_values = model(input_ids=seq2, use_cache=True).past_key_values
new_seq = torch.tensor([4, 5])
magic = model(input_ids=new_seq, past_key_values=key_values).logits
print(torch.equal(original_out[-1, :], magic[-1, :]))
字符串
但这将返回False,而我希望它返回True。
1条答案
按热度按时间tyg4sfes1#
你的代码很好,但你遇到了一些浮点精度问题。torch.equal检查两个Tensor是否具有相同的形状和完全相同的值,但你的两个变量略有不同:
字符串
输出量:
型
我推荐使用torch.allclose来比较两个Tensor,因为它考虑了一些公差:
型
输出量:
型