python-3.x 在自定义数据集上训练微调的TrOCR时,损失和CER增加

mnemlml8  于 2023-05-23  发布在  Python
关注(0)|答案(1)|浏览(416)

我想在我的自定义收据数据集上训练TrOCR。因为我们要对收据使用OCR,所以我们选择了“打印”的微调模型。我们使用5000个边界框的数据集,每个边界框包含一个单词。然而,我们的经验是,所有指标(cer,精度)和损失对于我们运行的每个时期都在恶化。我们无法弄清楚为什么模型在每个时期的表现都更差。
处理器、模型和优化器如下所示:

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

培训方式:

for epoch in range(self.epochs):
    self.model.train()
    train_loss = 0.0
    for batch in tqdm(self.train_dataloader):
        for k, v in batch.items():
            batch[k] = v.to(self.device)
        outputs = self.model(**batch)
        loss = outputs.loss
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        train_loss += loss.item()

有人知道我们做错了什么吗?
在开箱即用评估模型时,它表现良好,但我们希望继续在收据上进行训练,以改进模型。

iyzzxitl

iyzzxitl1#

您选择了已经微调过的模型。“microsoft/trocr-base-printed”已经在SROIE数据集上进行了微调。所以没有一点微调已经微调过的模型。相反,只选择预训练模型,如trocr-base-stage1或trocr-small-stage1。

相关问题