pytorch 你能确定HuggingFace分割模型中输出类的数量吗?

g9icjywg  于 2023-04-06  发布在  其他
关注(0)|答案(1)|浏览(188)

我正在加载这样的模型:

id2label = {
        0: 'background', 
        1: 'cake',
        2: 'donut', 
}
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.backbone, id2label=self.id2label, ignore_mismatched_sizes=True)  
model.load_state_dict(torch.load('checkpoint.pt', map_location=torch.device('cpu')))

然而,我实际上并不知道id2label。(我只有检查点)。我并不真正关心类的名称,我只是想知道检查点模型中有多少个类。我可以在出现的警告消息中看到它,但希望避免这种情况:

RuntimeError: Error(s) in loading state_dict for Mask2FormerForUniversalSegmentation:
    size mismatch for class_predictor.weight: copying a param with shape torch.Size([8, 256]) from checkpoint, the shape in current model is torch.Size([20, 256]).
    size mismatch for class_predictor.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
    size mismatch for criterion.empty_weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
vc9ivgsu

vc9ivgsu1#

您可以查看state_dict

import torch

chk = torch.load(checkpoint.pt)
# chk is a dict[str, torch.tensor]
# The layer shape tells you the number of labels +1 (i.e. subtract 1)
chk["class_predictor.weight"].shape[0]

唯一的缺点是您需要知道层的名称,但当您只加载Mask2FormerForUniversalSegmentation时,这是可行的。

相关问题