Keras自定义训练循环与V1兼容

wgxvkvu9  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(151)

**在tensorflow 2中向Keras添加自定义训练循环的正确方法是什么,但兼容V1?**为了解释为什么需要这样做,我熟悉在现代Keras模型中重载train_step()方法。然而,我正在做一个我在tensorflow 2之前开始的项目,它不支持这种方法。我能够升级并让我的代码在新版本中再次工作。但是,我遇到了与以下问题相关的严重性能和内存问题:

我尝试了这些问题和其他地方建议的所有技巧,但我没有达到与在兼容模式下运行代码相同的性能。我做这个的时候

tf.compat.v1.disable_eager_execution()

区别在于性能上的两个因素,以及导致我耗尽RAM的类似内存泄漏的症状(我在CPU上运行)。我真的需要使用兼容模式。不幸的是,当我在tensorflow 2中使用兼容模式时,模型不再在我的tf.keras.Model对象中调用train_step(),也不使用我的自定义训练。
这让我不禁要问:**如何在兼容tensorflow 1的Keras模型中实现自定义训练?**具体来说,我想做的自定义训练类型是添加一个软约束,在这里我评估问题域中的点并评估额外的损失项。这些点应该是随机选择的,不需要在训练集中。这看起来像下面这样。

def train_step(self, data):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    x, y = data

    # Make inputs for the soft constraint
    b = self.bounds  # Numpy array defining boundaries of the input variables
    x0 = (np.random.random((b.shape[1], self.n)) * (b[1] - b[0])[:, None] + b[0][:, None]).T

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Calculate constraint loss
        loss += self.lambda * constraint(self(x0, training=True))

    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # Update weights
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)

    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}

我已经研究了损失层和附加损失函数,但这些似乎不允许我在任意的额外点上评估模型。

kulphzqa

kulphzqa1#

我猜你的内存问题与tensorflow 1的向后兼容性没有直接关系,而是与tensorflow 2的已知内存泄漏问题有关:例如参见这些link1link2
解决方法是,在超参数搜索的每个训练会话结束时,清除tensorflow会话,然后再次重新编译模型。

import gc
from tensorflow.keras import backend as K
...
K.clear_session()
gc.collect()

相关问题