我正在尝试加载检查点,但收到错误消息。如何保存和加载检查点:
1.节省
torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, chkpt_path)
字符串
1.加载
chkpt = torch.load(chkpt_path, map_location=lambda storage, loc: storage)
net.load_state_dict(chkpt['model_state_dict'])
型
在执行以下行“net.load_state_dict(chkpt 'model_state_dict']”后,我的所有图层都多次出现以下错误:While copying the parameter named "residual_layer.residual_block ...", whose dimensions in the model are torch.Size([1, 524, 32, 2]) and whose dimensions in the checkpoint are torch.Size([1, 524, 32, 2]), an exception occurred : ('unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.',).
如何使用clone()函数正确加载检查点字典?我在cpu上训练了模型,并希望在cpu上进行推理。
我试图将clone函数添加到load_state_dict()中,但没有成功。
2条答案
按热度按时间a0x5cqrl1#
问题出在你的模型中,你的模型中有一个Tensor,它使用相同的内存来存储多个值,并且有东西试图写入它。
例如,我可以使用
expand
创建一个Tensor,其中多行占用同一内存,并且修改一个值实际上会更改两个值:字符串
你的模型中有类似的东西,所以我假设检查点加载代码不知道如何处理它,但是检查堆栈跟踪来确定。它要求你克隆有问题的Tensor(在你的模型中)。例如
x = x.clone()
我可以修改上面的代码,添加
clone
,以避免重复的99型
你的问题可能不是
expand
,有很多pytorch函数可能会导致不连续的Tensor。仅供参考,你也可以使用.contiguous()
来代替.clone()
。另外,这里有一个对上面的Tensor进行就地修改的例子,复制你的错误消息。同样,有很多方法可以获得这样的内存共享Tensor,也有很多方法可以就地修改它们:
型
xriantvc2#
我的假设是,您使用DataParallel来训练模型。如果您的模型具有名为“weight”的参数,则在使用DataParallel时,状态字典中的键将是“module. weight”。DataParallel会自动使用具有此前缀的模块 Package 模型。
所以试试看:
字符串
在
net.load_state_dict(chkpt['model_state_dict'])
之前。“