我已经在Pytorch中对一些图像训练了一个模型。训练后我保存了模型的权重。我的模型的代码如下:
基于ImageClassificationBase类的自定义模型-〉
import torch
import torch.nn as nn
# Define the new model with the same architecture as the trained model
class NewModel(nn.Module):
def __init__(self):
super(NewModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
self.fc1 = nn.Linear(128 * 30 * 30, 512)
self.fc2 = nn.Linear(512, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(-1, 128 * 30 * 30)
x = self.fc1(x)
x = self.fc2(x)
return x
现在我想把这些权重转移到一个新的模型中,这个模型的结构与之前的模型相同;但我只想在展平操作之前使用权重;之后,我想根据我的新任务重新训练线性层。
我的目标:
# Load the saved weights from the trained model
trained_model_path = "/content/model_weights.pth"
trained_model_state_dict = torch.load(trained_model_path)
# Create the new model
new_model = Classifier_model()
# Remove the keys corresponding to the layers that you don't want to initialize
new_model_state_dict = new_model.state_dict()
for key in list(new_model_state_dict.keys()):
if key.startswith('fc'):
new_model_state_dict[key]
# Update the state dictionary with the learned weights
trained_model_state_dict.update(new_model_state_dict)
new_model.load_state_dict(trained_model_state_dict)
但这显示了错误,因为我不能只删除fc层对应的权重,而我需要用随机数替换它们。
有人能建议我怎么做吗?
1条答案
按热度按时间kmbjn2e31#
您只需加载已训练模型的状态字典,然后弹出以
'fc'
开头的键:然后,您可以继续使用
load_state_dict
在新初始化的模型上加载此部分状态字典,并将严格参数设置为False
:它不会对丢失的键抛出错误,而是提供一个丢失的和不期望的键的列表。