我正在加载这样的模型:
id2label = {
0: 'background',
1: 'cake',
2: 'donut',
}
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.backbone, id2label=self.id2label, ignore_mismatched_sizes=True)
model.load_state_dict(torch.load('checkpoint.pt', map_location=torch.device('cpu')))
然而,我实际上并不知道id2label
。(我只有检查点)。我并不真正关心类的名称,我只是想知道检查点模型中有多少个类。我可以在出现的警告消息中看到它,但希望避免这种情况:
RuntimeError: Error(s) in loading state_dict for Mask2FormerForUniversalSegmentation:
size mismatch for class_predictor.weight: copying a param with shape torch.Size([8, 256]) from checkpoint, the shape in current model is torch.Size([20, 256]).
size mismatch for class_predictor.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
size mismatch for criterion.empty_weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
1条答案
按热度按时间vc9ivgsu1#
您可以查看state_dict:
唯一的缺点是您需要知道层的名称,但当您只加载
Mask2FormerForUniversalSegmentation
时,这是可行的。