pytorch 加载模型检查点的模型状态字典时出错

cbjzeqam  于 9个月前  发布在  其他
关注(0)|答案(2)|浏览(142)

我正在尝试加载检查点,但收到错误消息。如何保存和加载检查点:
1.节省

torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, chkpt_path)

字符串
1.加载

chkpt = torch.load(chkpt_path, map_location=lambda storage, loc: storage)
net.load_state_dict(chkpt['model_state_dict'])


在执行以下行“net.load_state_dict(chkpt 'model_state_dict']”后,我的所有图层都多次出现以下错误:While copying the parameter named "residual_layer.residual_block ...", whose dimensions in the model are torch.Size([1, 524, 32, 2]) and whose dimensions in the checkpoint are torch.Size([1, 524, 32, 2]), an exception occurred : ('unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.',).
如何使用clone()函数正确加载检查点字典?我在cpu上训练了模型,并希望在cpu上进行推理。
我试图将clone函数添加到load_state_dict()中,但没有成功。

a0x5cqrl

a0x5cqrl1#

问题出在你的模型中,你的模型中有一个Tensor,它使用相同的内存来存储多个值,并且有东西试图写入它。
例如,我可以使用expand创建一个Tensor,其中多行占用同一内存,并且修改一个值实际上会更改两个值:

>>> x = torch.tensor([1,2,3]).expand( (2,3) )
>>> x
tensor([[1, 2, 3],
        [1, 2, 3]])
>>> x[0,0]
tensor(1)
>>> x[0,0] = 99
>>> x
tensor([[99,  2,  3],
        [99,  2,  3]])
>>>

字符串
你的模型中有类似的东西,所以我假设检查点加载代码不知道如何处理它,但是检查堆栈跟踪来确定。它要求你克隆有问题的Tensor(在你的模型中)。例如x = x.clone()
我可以修改上面的代码,添加clone,以避免重复的99

>>> x = torch.tensor([1,2,3]).expand( (2,3) ).clone()
>>> x[0,0]=99
>>> x
tensor([[99,  2,  3],
        [ 1,  2,  3]])


你的问题可能不是expand,有很多pytorch函数可能会导致不连续的Tensor。仅供参考,你也可以使用.contiguous()来代替.clone()
另外,这里有一个对上面的Tensor进行就地修改的例子,复制你的错误消息。同样,有很多方法可以获得这样的内存共享Tensor,也有很多方法可以就地修改它们:

>>> x = torch.tensor([1,2,3]).expand( (2,3) )
>>> x.mul_(5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.

xriantvc

xriantvc2#

我的假设是,您使用DataParallel来训练模型。如果您的模型具有名为“weight”的参数,则在使用DataParallel时,状态字典中的键将是“module. weight”。DataParallel会自动使用具有此前缀的模块 Package 模型。
所以试试看:

if 'module.' in list(chkpt['model_state_dict'].keys())[0]:
new_state_dict = {}
for key, value in chkpt['model_state_dict'].items():
    new_key = key.replace('module.', '')
    new_state_dict[new_key] = value
chkpt['model_state_dict'] = new_state_dict

字符串
net.load_state_dict(chkpt['model_state_dict'])之前。“

相关问题