我正在使用pytorch-forecasting库(基于pytorch-lightning)运行一个时间序列预测的TFT模型。我的训练例程分为三个不同的任务。首先我使用optuna执行HPO,然后我执行训练+验证,最后,使用完整数据进行再训练(没有验证)。
目前,训练+验证和再训练都是使用新模型从头开始,因此运行时间相当长。因此,我尝试利用增量训练来减少整个训练例程的运行时间,在增量训练中,我将从第2阶段加载经过检查点训练的模型,并在第3阶段的较小时期对其进行再训练。
我有一个方法fit_model()
,它在训练/验证和再训练中都使用,但参数不同,fit()的核心部分如下所示:
def fit_model(self, **kwargs):
...
to_retrain = kwargs.get('to_retrain', False)
ckpt_path = kwargs.get('ckpt_path', None)
trainer = self._get_trainer(cluster_id, gpu_id, to_retrain) # returns a pl.Trainer object
tft_lightning_module = self._prepare_for_training(cluster_id, to_retrain)
train_dtloaders = ...
val_dtloaders = ...
if not to_retrain:
trainer.fit(
tft_lightning_module,
train_dataloaders=train_dtloaders,
val_dataloaders=val_dtloaders
)
else:
trainer.fit(
tft_lightning_module,
train_dataloaders=train_dtloaders,
val_dataloaders=val_dtloaders,
ckpt_path=ckpt_path
)
best_model_path = trainer.checkpoint_callback.best_model_path
return best_model_path
当我在重新训练阶段调用上述方法时,我可以看到日志,其中显示正在加载检查点模型:
Restored all states from the checkpoint file at /tft/incremental_training/tft_training_20230206/171049/lightning_logs_3/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt
但遗憾的是,在第3阶段没有进行进一步的训练。如果我查看该方法返回的best_model_path
,它具有来自训练/验证阶段而不是再训练阶段的旧检查点路径。如何解决此问题?
我正在使用以下库
pytorch-lightning==1.6.5
pytorch-forecasting==0.9.0
1条答案
按热度按时间g6ll5ycj1#
我终于让它工作了,这里要记住的关键是,在训练和再训练中,不要使用相同的历元数,如果我们训练x个历元,并打算再运行y个历元,那么max-epochs必须设置为x+y,而不是再训练中的y。