我想在我的自定义收据数据集上训练TrOCR。因为我们要对收据使用OCR,所以我们选择了“打印”的微调模型。我们使用5000个边界框的数据集,每个边界框包含一个单词。然而,我们的经验是,所有指标(cer,精度)和损失对于我们运行的每个时期都在恶化。我们无法弄清楚为什么模型在每个时期的表现都更差。
处理器、模型和优化器如下所示:
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
培训方式:
for epoch in range(self.epochs):
self.model.train()
train_loss = 0.0
for batch in tqdm(self.train_dataloader):
for k, v in batch.items():
batch[k] = v.to(self.device)
outputs = self.model(**batch)
loss = outputs.loss
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
train_loss += loss.item()
有人知道我们做错了什么吗?
在开箱即用评估模型时,它表现良好,但我们希望继续在收据上进行训练,以改进模型。
1条答案
按热度按时间iyzzxitl1#
您选择了已经微调过的模型。“microsoft/trocr-base-printed”已经在SROIE数据集上进行了微调。所以没有一点微调已经微调过的模型。相反,只选择预训练模型,如trocr-base-stage1或trocr-small-stage1。