pytorch 加速torch.save

weylhg0b  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(146)

作为训练循环的一部分,每次epoch的验证分数提高时,我都会存储模型的检查点,就像通常所做的那样:

if dice_score > best_dice_score:
    model_state = {
        "epoch": epoch,
        "model_state_dict": self.model.state_dict(),
        "score": dice_score
    }

    self.best_epoch = epoch
    best_dice_score = dice_score

    torch.save(model_state, "somefile.pt")

它工作正常。但是,它真的很慢。检查点文件约为750 MB,并写入SSD。这大约占用了总列车时间的30%,我想减少这个时间。如何做到这一点?

xytpbqjk

xytpbqjk1#

您是否检查了保存模型的真实的估计时间?如果您使用SSD,写入时间将低于1秒。
其他一些脚本会减慢整个训练时间,例如每个epoch上的dataloader。

相关问题