我正在尝试使用keras model.fit()
方法训练一个模型。这个方法返回一个历史对象,其中包含每个时期的损失值-但是我希望每个批次都有损失值。
我在网上找到了一些建议,建议使用一个带有on_batch_end(self, logs={})
方法的自定义回调类。问题是这个方法只获得传递的 * 聚合 * 统计信息,这些统计信息在每个时期都会被重置。我希望每个批次都有 * 单独 * 的统计信息。
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_train_batch_end
2条答案
按热度按时间i7uq4tfw1#
您可以使用自定义训练循环轻松地完成这一任务,在该循环中,您只需将每个批次的损失值附加到一个列表中。
以下是所有操作的方法:
ippsafx72#
根据提供的平均损耗,可以计算出当前批次的损耗,如下所示:
此代码可以添加到任何需要当前批次损失而不是平均损失的自定义回调中。
此外,如果您使用的是Tensorflow 1或TensorFlow 2版本〈= 2.1,则不要在回调中包含此代码,因为在这些版本中,已经提供了当前损耗,而不是平均损耗。