作为训练循环的一部分,每次epoch的验证分数提高时,我都会存储模型的检查点,就像通常所做的那样:
if dice_score > best_dice_score:
model_state = {
"epoch": epoch,
"model_state_dict": self.model.state_dict(),
"score": dice_score
}
self.best_epoch = epoch
best_dice_score = dice_score
torch.save(model_state, "somefile.pt")
它工作正常。但是,它真的很慢。检查点文件约为750 MB,并写入SSD。这大约占用了总列车时间的30%,我想减少这个时间。如何做到这一点?
1条答案
按热度按时间xytpbqjk1#
您是否检查了保存模型的真实的估计时间?如果您使用SSD,写入时间将低于1秒。
其他一些脚本会减慢整个训练时间,例如每个epoch上的dataloader。