pytorch 在中断事件中保存和恢复Skorch GridSearchCV

gev0vcfq  于 2023-05-29  发布在  其他
关注(0)|答案(1)|浏览(267)

我正在与Skorch合作,并使用GridSearchCV执行网格搜索。然而,我担心如果在搜索过程中发生意外事件(如系统故障或中断)会发生什么。在这种情况下,我想保存模型的进展,并从我停止的地方继续网格搜索。
我尝试利用Skorch中的检查点回调来实现这一目的。但是,我不确定在Skorch中正确保存和加载模型状态的正确方法。谁能提供一个全面的例子或指导我实现这一点?

6l7fqoea

6l7fqoea1#

您是否检查/尝试过在网格搜索期间使用检查点回调来保存和加载模型的状态?
下面是一个简单的例子,因为你没有提供任何代码:

from skorch.callbacks import Checkpoint
from skorch import NeuralNetClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from torch import nn

# simple neural network classifier
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc(x)

# Skorch NeuralNetClassifier
net = NeuralNetClassifier(
    Net,
    max_epochs=10,
    lr=0.1,
    callbacks=[Checkpoint(monitor='valid_acc_best', f_params='best_model.pt')],
)

# fake data
X, y = make_classification(n_samples=100, n_features=20, random_state=42)

# grid search params
param_grid = {
    'lr': [0.1, 0.01, 0.001],
    'module__hidden_units': [10, 20, 30],
}

# Here we can use Checkpoint callback to monitor the search
gs = GridSearchCV(net, param_grid, scoring='accuracy', cv=3, refit=True)
gs.fit(X, y)

要加载保存的模型,请使用用途:

best_model = Net()
best_model.load_state_dict(torch.load('best_model.pt'))

相关问题