python 尝试加载PyTorch模型时出现大小不匹配的运行时错误

piv4azn7  于 2023-03-21  发布在  Python
关注(0)|答案(2)|浏览(373)

下面是我尝试运行的代码。fasterrcnn_foodtracker.pth是我尝试用PyTorch加载的已经训练好的模型。

import torch
import torchvision
import cv2

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

path = '/home/amir/PycharmProjects/Food-Recognition/fasterrcnn_foodtracker.pth'
model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
model.eval()

img = cv2.imread('twodishes.jpg')
prediction = model([img])
print(prediction)

出现大小不匹配的运行时错误。

RuntimeError: Error(s) in loading state_dict for FasterRCNN:
    size mismatch for roi_heads.box_predictor.cls_score.weight: copying a param with shape torch.Size([100, 1024]) from checkpoint, the shape in current model is torch.Size([91, 1024]).
    size mismatch for roi_heads.box_predictor.cls_score.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for roi_heads.box_predictor.bbox_pred.weight: copying a param with shape torch.Size([400, 1024]) from checkpoint, the shape in current model is torch.Size([364, 1024]).
    size mismatch for roi_heads.box_predictor.bbox_pred.bias: copying a param with shape torch.Size([400]) from checkpoint, the shape in current model is torch.Size([364]).
xggvc2p6

xggvc2p61#

在我看来,您的模型配置与模型检查点的内容不匹配。我想象您的模型具有诸如input_sizeoutput_size(或num_classes)之类的参数,这些参数最终将定义您的模型的外观(每层的参数数量等)。
简单示例:

# two layer linear 'network', `num_in` input units, `num_out` output units
model = nn.Sequential(nn.Linear(num_in, 100), nn.Linear(100, num_out))

如果你为num_in = 10num_out = 20训练并保存这个模型,将这些参数更改为num_in = 12/num_out = 22,并加载你之前保存的模型,加载例程将抱怨形状不匹配(10对12和20对22)。
这似乎是发生在你身上的事情。解决方案:您需要确保使用与用于生成正在加载的检查点的模型相同的超参数来初始化模型。

w8f9ii69

w8f9ii692#

尝试加载为pretrained=False,然后更新状态指令。这是因为RCNN模型是在x数量的类上训练的,但您正在尝试匹配来自另一个模型的权重,该模型可能是在不同数量的类上训练的。

相关问题