How to load part of a model from saved weights in PyTorch

Posted by qnyhuwrf on 2023-03-30 in Other

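I have trained a model on some images in PyTorch and saved the model's weights after training. The code for my model (a custom model based on the ImageClassificationBase class) is as follows: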

import torch
import torch.nn as nn

# Define the new model with the same architecture as the trained model

class NewModel(nn.Module):
    def __init__(self):
        super(NewModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        self.fc1 = nn.Linear(128 * 30 * 30, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, 128 * 30 * 30)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
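`fc1` hard-codes 128 * 30 * 30 input features. Both convolutions use 3x3 kernels with PyTorch's defaults (stride 1, no padding), so each one shrinks the height and width by 2, and that feature count only matches 34x34 inputs. A small sketch, assuming square inputs, that applies the output-size formula from the `nn.Conv2d` documentation (`conv2d_out` and `flatten_features` are hypothetical helpers, not part of the question):

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flatten_features(input_size, channels=128, kernel_size=3, num_convs=2):
    """Number of features after conv1 -> conv2 -> flatten, for a square input."""
    size = input_size
    for _ in range(num_convs):
        size = conv2d_out(size, kernel_size)  # each 3x3 conv: size - 2
    return channels * size * size


print(flatten_features(34))  # 115200, i.e. 128 * 30 * 30
```

Any other input size changes the number of flattened features, which would then no longer match `fc1`.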

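Now I want to transfer these weights to a new model with the same architecture as the previous one, but I only want to use the weights up to the flatten operation; after that, I want to retrain the linear layers for my new task.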

My attempt:

# Load the saved weights from the trained model
trained_model_path = "/content/model_weights.pth"
trained_model_state_dict = torch.load(trained_model_path)

# Create the new model
new_model = Classifier_model()

# Remove the keys corresponding to the layers that you don't want to initialize
new_model_state_dict = new_model.state_dict()
for key in list(new_model_state_dict.keys()):
    if key.startswith('fc'):
         new_model_state_dict[key]

# Update the state dictionary with the learned weights
trained_model_state_dict.update(new_model_state_dict)

new_model.load_state_dict(trained_model_state_dict)

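But this raises an error, because I can't simply delete the weights of the fc layers; I need to replace them with random values instead.
Can anyone suggest how to do this?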
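In the attempt above, `new_model_state_dict[key]` only looks the entry up and has no effect, and `trained_model_state_dict.update(new_model_state_dict)` copies every entry of the new model, including its random conv weights, over the trained ones. Reversing the direction of the merge keeps the random fc entries and takes everything else from the trained model. A minimal sketch, assuming the layers to retrain are `fc1` and `fc2`; `merge_state_dicts` is a hypothetical helper, and plain strings stand in for tensors so the logic runs without torch:

```python
def merge_state_dicts(new_sd, trained_sd, reinit_prefixes=("fc1.", "fc2.")):
    """Start from the freshly initialized entries and copy in every trained
    entry except those of the layers that should be retrained."""
    merged = dict(new_sd)  # shallow copy: keeps the new model's random fc entries
    for key, value in trained_sd.items():
        # Only overwrite keys the new model actually has, and skip the fc layers.
        if key in merged and not key.startswith(reinit_prefixes):
            merged[key] = value
    return merged


keys = ["conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias",
        "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
new_sd = {k: "random" for k in keys}       # stand-in for new_model.state_dict()
trained_sd = {k: "trained" for k in keys}  # stand-in for torch.load(...)
merged = merge_state_dicts(new_sd, trained_sd)
```

With real models this would be `new_model.load_state_dict(merge_state_dicts(new_model.state_dict(), trained_model_state_dict))`. Because the merged dict has exactly the new model's keys, it loads with the default `strict=True`.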

Answer by kmbjn2e3:

You can simply load the trained model's state dict and drop the keys that start with 'fc':

>>> trained = torch.load('model.pth')
>>> trained_trim = {k:v for k, v in trained.items() if not k.startswith('fc')}
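The filter above works for `NewModel`, whose only keys are `conv1.*`, `conv2.*`, `fc1.*` and `fc2.*`. Note that `startswith('fc')` drops every key whose name begins with `fc`, so in a model with other such layers it could discard more than intended. A sketch of a narrower filter, assuming the layers to retrain are exactly `fc1` and `fc2` (placeholder integers stand in for the tensors):

```python
# Key names as produced by NewModel.state_dict(); values are placeholders.
trained = {
    "conv1.weight": 0, "conv1.bias": 0,
    "conv2.weight": 0, "conv2.bias": 0,
    "fc1.weight": 0, "fc1.bias": 0,
    "fc2.weight": 0, "fc2.bias": 0,
}

# str.startswith accepts a tuple; the trailing "." means only keys that belong
# to the modules fc1 and fc2 themselves are dropped.
drop = ("fc1.", "fc2.")
trained_trim = {k: v for k, v in trained.items() if not k.startswith(drop)}
print(sorted(trained_trim))
# ['conv1.bias', 'conv1.weight', 'conv2.bias', 'conv2.weight']
```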

Then you can load this partial state dict into a freshly initialized model with load_state_dict, setting the strict argument to False:

>>> model = NewModel()
>>> model.load_state_dict(trained_trim, strict=False)

Instead of raising an error about the missing keys, it returns the lists of missing and unexpected keys:

_IncompatibleKeys(
   missing_keys=['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'], 
   unexpected_keys=[])
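If, as the question puts it, the fc weights should be replaced with random values, a different route from the one in this answer is to load the complete trained state dict and then call `reset_parameters()` on the layers to retrain; `nn.Linear` and `nn.Conv2d` both provide this method, which re-applies the layer's default initialization. A sketch, assuming a hypothetical small `TinyNet` with the same layer names as `NewModel` so it runs quickly; `load_and_reinit` is likewise a hypothetical helper:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small stand-in with the same layer names as NewModel."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
        self.fc1 = nn.Linear(8 * 2 * 2, 16)  # sized for 6x6 inputs
        self.fc2 = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv2(self.conv1(x))
        return self.fc2(self.fc1(x.flatten(1)))


def load_and_reinit(model, trained_state_dict, reinit=("fc1", "fc2")):
    """Load every trained weight, then re-draw the listed layers' parameters."""
    model.load_state_dict(trained_state_dict)  # strict: all keys must match
    for name in reinit:
        getattr(model, name).reset_parameters()  # default init for that layer
    return model


trained = TinyNet()  # plays the role of the trained model
model = load_and_reinit(TinyNet(), trained.state_dict())
```

For the question's model this would be `load_and_reinit(NewModel(), torch.load(trained_model_path))`. Unlike `strict=False`, the strict load here still reports any key mismatch in the conv layers.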
