如何使用以前的检查点在新数据上重新训练基于pytorch-lightning的模型

olhwl3o2  于 2023-02-19  发布在  其他
关注(0)|答案(1)|浏览(288)

我正在使用pytorch-forecasting库(基于pytorch-lightning)运行一个时间序列预测的TFT模型。我的训练例程分为三个不同的任务。首先我使用optuna执行HPO,然后我执行训练+验证,最后,使用完整数据进行再训练(没有验证)。
目前,训练+验证和再训练都是使用新模型从头开始,因此运行时间相当长。因此,我尝试利用增量训练来减少整个训练例程的运行时间,在增量训练中,我将从第2阶段加载经过检查点训练的模型,并在第3阶段的较小时期对其进行再训练。
我有一个方法fit_model(),它在训练/验证和再训练中都使用,但参数不同,fit()的核心部分如下所示:

def fit_model(self, **kwargs):
    ...
    to_retrain = kwargs.get('to_retrain', False)
    ckpt_path = kwargs.get('ckpt_path', None)

    trainer = self._get_trainer(cluster_id, gpu_id, to_retrain)   # returns a pl.Trainer object 
    tft_lightning_module = self._prepare_for_training(cluster_id, to_retrain)

    train_dtloaders = ...
    val_dtloaders = ...

    if not to_retrain:
        trainer.fit(
            tft_lightning_module,
            train_dataloaders=train_dtloaders,
            val_dataloaders=val_dtloaders
        )
    else:
        trainer.fit(
            tft_lightning_module,
            train_dataloaders=train_dtloaders,
            val_dataloaders=val_dtloaders,
            ckpt_path=ckpt_path
        )

    best_model_path = trainer.checkpoint_callback.best_model_path    
    return best_model_path

当我在重新训练阶段调用上述方法时,我可以看到日志,其中显示正在加载检查点模型:

Restored all states from the checkpoint file at /tft/incremental_training/tft_training_20230206/171049/lightning_logs_3/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt

但遗憾的是,在第3阶段没有进行进一步的训练。如果我查看该方法返回的best_model_path,它具有来自训练/验证阶段而不是再训练阶段的旧检查点路径。如何解决此问题?
我正在使用以下库

pytorch-lightning==1.6.5
pytorch-forecasting==0.9.0
g6ll5ycj

g6ll5ycj1#

我终于让它工作了,这里要记住的关键是,在训练和再训练中,不要使用相同的历元数,如果我们训练x个历元,并打算再运行y个历元,那么max-epochs必须设置为x+y,而不是再训练中的y。

相关问题