Keras自定义训练循环与V1兼容

wgxvkvu9  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(150)

**在tensorflow 2中向Keras添加自定义训练循环的正确方法是什么,但兼容V1?**为了解释为什么需要这样做,我熟悉在现代Keras模型中重载train_step()方法。然而,我正在做一个我在tensorflow 2之前开始的项目,它不支持这种方法。我能够升级并让我的代码在新版本中再次工作。但是,我遇到了与以下问题相关的严重性能和内存问题:

我尝试了这些问题和其他地方建议的所有技巧,但我没有达到与在兼容模式下运行代码相同的性能。我做这个的时候

  1. tf.compat.v1.disable_eager_execution()

区别在于性能上的两个因素,以及导致我耗尽RAM的类似内存泄漏的症状(我在CPU上运行)。我真的需要使用兼容模式。不幸的是,当我在tensorflow 2中使用兼容模式时,模型不再在我的tf.keras.Model对象中调用train_step(),也不使用我的自定义训练。
这让我不禁要问:**如何在兼容tensorflow 1的Keras模型中实现自定义训练?**具体来说,我想做的自定义训练类型是添加一个软约束,在这里我评估问题域中的点并评估额外的损失项。这些点应该是随机选择的,不需要在训练集中。这看起来像下面这样。

  1. def train_step(self, data):
  2. # Unpack the data. Its structure depends on your model and
  3. # on what you pass to `fit()`.
  4. x, y = data
  5. # Make inputs for the soft constraint
  6. b = self.bounds # Numpy array defining boundaries of the input variables
  7. x0 = (np.random.random((b.shape[1], self.n)) * (b[1] - b[0])[:, None] + b[0][:, None]).T
  8. with tf.GradientTape() as tape:
  9. y_pred = self(x, training=True) # Forward pass
  10. # Compute the loss value
  11. # (the loss function is configured in `compile()`)
  12. loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
  13. # Calculate constraint loss
  14. loss += self.lambda * constraint(self(x0, training=True))
  15. # Compute gradients
  16. trainable_vars = self.trainable_variables
  17. gradients = tape.gradient(loss, trainable_vars)
  18. # Update weights
  19. self.optimizer.apply_gradients(zip(gradients, trainable_vars))
  20. # Update metrics (includes the metric that tracks the loss)
  21. self.compiled_metrics.update_state(y, y_pred)
  22. # Return a dict mapping metric names to current value
  23. return {m.name: m.result() for m in self.metrics}

我已经研究了损失层和附加损失函数,但这些似乎不允许我在任意的额外点上评估模型。

kulphzqa

kulphzqa1#

我猜你的内存问题与tensorflow 1的向后兼容性没有直接关系,而是与tensorflow 2的已知内存泄漏问题有关:例如参见这些link1link2
解决方法是,在超参数搜索的每个训练会话结束时,清除tensorflow会话,然后再次重新编译模型。

  1. import gc
  2. from tensorflow.keras import backend as K
  3. ...
  4. K.clear_session()
  5. gc.collect()

相关问题