我有一个python脚本,用于训练和测试CNN模型。模型权重/参数在测试后通过以下方式保存:
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, path + filename)
保存后,我立即通过使用函数加载模型:
model_load = create_model(cnn_type="vgg", numberofclasses=len(cases))
然后,我通过以下方式加载模型权重/参数:
model_load.load_state_dict(torch.load(filePath+filename), strict = False)
model_load.eval()
最后,我将保存模型之前使用的相同测试数据提供给这个模型。
问题是,当我比较储存之前和载入之后的模型测试结果时,测试结果并不相同。我的直觉是,由于strict = False,某些参数并未传递至模型。但是,当我设定strict = True时,我收到错误。是否有解决方法?
错误消息为:
RuntimeError: Error(s) in loading state_dict for CNN:
Missing key(s) in state_dict: "linear.weight", "linear.bias", "linear 2.weight", "linea r2.bias", "linear 3.weight", "linear3.bias". Unexpected key(s) in state_dict: "state_dict", "optimizer".
1条答案
按热度按时间chhqkbe11#
您正在加载一个包含模型状态和优化程序状态的字典。根据错误堆栈跟踪,以下操作应该可以解决此问题: