pytorch num_labels实际上是做什么的?

bihw5rsg  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(243)

当训练基于BERT的模型时,可以设置num_labels
AutoConfig.from_pretrained(BERT_MODEL_NAME, num_labels=num_labels)
因此,例如,如果我们想要预测3个值,我们可以使用num_labels=3
我的问题是它在内部做什么?它只是把一个nn.Linear连接到最后一个嵌入层吗?
谢谢

h5qlskok

h5qlskok1#

我假设如果有一个num标签,那么模型就用于分类,然后您可以简单地转到BERT关于拥抱脸的文档,然后搜索分类类并查看代码,然后您会发现以下内容:https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/bert/modeling_bert.py#L1572

if labels is not None:
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)

所以我们看到的标签数量影响损失函数的使用
我希望这能回答你的问题

相关问题