与保存的模型相比,加载的PyTorch模型具有不同的结果

sr4lhrrt  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(165)

我有一个python脚本,用于训练和测试CNN模型。模型权重/参数在测试后通过以下方式保存:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, path + filename)

保存后,我立即通过使用函数加载模型:

model_load = create_model(cnn_type="vgg", numberofclasses=len(cases))

然后,我通过以下方式加载模型权重/参数:

model_load.load_state_dict(torch.load(filePath+filename), strict = False)    
model_load.eval()

最后,我将保存模型之前使用的相同测试数据提供给这个模型。
问题是,当我比较储存之前和载入之后的模型测试结果时,测试结果并不相同。我的直觉是,由于strict = False,某些参数并未传递至模型。但是,当我设定strict = True时,我收到错误。是否有解决方法?
错误消息为:

RuntimeError: Error(s) in loading state_dict for CNN:
        Missing key(s) in state_dict: "linear.weight", "linear.bias", "linear 2.weight", "linea r2.bias", "linear 3.weight", "linear3.bias". Unexpected key(s) in state_dict: "state_dict", "optimizer".
chhqkbe1

chhqkbe11#

您正在加载一个包含模型状态和优化程序状态的字典。根据错误堆栈跟踪,以下操作应该可以解决此问题:

>>> model_state = torch.load(filePath+filename)['state_dict']
>>> model_load.load_state_dict(model_state, strict=True)

相关问题