python 在机器学习代码中保存训练的权重

8yoxcaq7  于 2023-08-02  发布在  Python
关注(0)|答案(1)|浏览(181)

我有colab用于运行机器学习模型,但当它得到80 epoch时,我的colab ram崩溃了,我无法进入80 epoch。我希望有人能帮助我将训练的权重保存在某个地方,在ram崩溃后,我开始从那个epoch开始训练模型。这是我的代码,我如何在这个python代码中编写目的代码?

for comm_round in range(comms_round):

    global_weights = global_model.get_weights()

    scaled_local_weight_list = list()

    client_names= list(clients_batched.keys())
    random.shuffle(client_names)

    for client in client_names:
        local_model = Transformer
        local_model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
                            optimizer=tf.keras.optimizers.Adam(learning_rate = 0.001),
                            metrics='acc')

        local_model.set_weights(global_weights)

        local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])

        scaling_factor = weight_scalling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)

        K.clear_session()

    average_weights = sum_scaled_weights(scaled_local_weight_list)

    global_model.set_weights(average_weights)

    for(X_test, Y_test) in test_batched:
        global_acc, global_loss = test_model(test_x, test_y, global_model, comm_round + 1)

字符串
此代码用于最后一步和联邦学习。

toiithl6

toiithl61#

潜在的解决方案

注意:我看到checkpoint_callbackfit函数中,我想这可能是一个模型回调函数?如果是这样的话,我不确定这是否可行,但这是我在使用TensorFlow框架进行训练时,试图保存最佳权重时想到的一般方法。由于这个过程是连续的跨纪元,即使它在一个较晚的纪元失败,我猜最好的运行之前应该已经保存下来。

您可以使用TensorFlow的ModelCheckpoint回调在每个纪元后自动保存模型的权重。然后,如果运行时崩溃,可以加载这些权重以恢复训练。以下是如何设置它:
1.在训练循环之前定义ModelCheckpoint回调:

checkpoint_filepath = '/path/to/checkpoint/directory'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',  # change this to what you want to monitor
    mode='auto', 
    save_best_only=False,
    verbose=1,
)

字符串
1.在fit函数中包括ModelCheckpoint回调:

for comm_round in range(comms_round):
    # ... existing code ...

    for client in client_names:
        # ... existing code ...

        local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])

        # ... existing code ...


在每个时期之后,模型的权重将保存到指定的文件路径。
1.如果运行时崩溃并且需要继续训练,则可以将保存的权重加载到模型中:

model.load_weights(checkpoint_filepath)


请注意,要保存模型权重的目录的路径必须存在。如果您使用的是Google Colab,最好将您的体重保存在Google Drive中。通过这种方式,即使Colab运行时被回收,保存的权重也将保持不变。为此,您需要将Google Drive安装到Colab笔记本电脑上。
另外,请注意,这不会自动保存当前的沟通轮次。您将需要手动管理。例如,您可以在每次保存模型权重时,将当前的通信轮数保存到文件中。当您恢复培训时,您可以从文件中读取此编号,并从那里继续进行沟通。

相关问题