How do I get unaggregated per-batch statistics with tf.keras.callbacks?

zhte4eai · asked 2022-11-13

I'm trying to train a model with Keras's model.fit() method. This method returns a History object containing the loss value for each *epoch*, but I would like the loss value for each *batch*.
I found suggestions online to use a custom callback class with an on_batch_end(self, logs={}) method. The problem is that this method is only passed *aggregated* statistics, which are reset at the start of each epoch. I want *individual* statistics for each batch.
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_train_batch_end

i7uq4tfw · answer #1

You can do this easily with a custom training loop, in which you simply append each batch's loss value to a list:

train_loss_per_train_batch.append(loss_value.numpy())

Here's a complete, end-to-end example:

import tensorflow as tf
import tensorflow_datasets as tfds

ds = tfds.load('iris', split='train', as_supervised=True)

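# Iris has 150 examples: use the first 125 for training, the remaining 25 for testing.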
train = ds.take(125).shuffle(16).batch(4)
test = ds.skip(125).take(25).shuffle(16).batch(4)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])

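# The model ends in softmax, so it outputs probabilities (from_logits=False).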
loss_object = tf.losses.SparseCategoricalCrossentropy(from_logits=False)

def compute_loss(model, x, y, training):
    out = model(x, training=training)
    loss = loss_object(y_true=y, y_pred=out)
    return loss

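# Compute the batch loss and its gradients w.r.t. the trainable weights.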
def get_grad(model, x, y):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y, training=True)
    return loss, tape.gradient(loss, model.trainable_variables)

optimizer = tf.optimizers.Adam()

verbose = "Epoch {:2d} Loss: {:.3f} TLoss: {:.3f} Acc: {:.2%} TAcc: {:.2%}"

train_loss_per_train_batch = list()

for epoch in range(1, 25 + 1):
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.SparseCategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.SparseCategoricalAccuracy()

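    # Training: take one optimizer step per batch.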
    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))
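        # Record this batch's raw, unaggregated loss (what the question asks for).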
        train_loss_per_train_batch.append(loss_value.numpy())

    for x, y in test:
        loss_value, _ = get_grad(model, x, y)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, model(x, training=False))

    print(verbose.format(epoch,
                         train_loss.result(),
                         test_loss.result(),
                         train_acc.result(),
                         test_acc.result()))
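
After training, train_loss_per_train_batch holds one value per batch across all epochs. A minimal sketch for inspecting it, assuming matplotlib is installed (this part is not from the original answer):

import matplotlib.pyplot as plt

plt.plot(train_loss_per_train_batch)
plt.xlabel('batch index (across all epochs)')
plt.ylabel('training loss')
plt.show()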
ippsafx7 · answer #2

From the running-average loss that Keras reports, you can recover the loss of the current batch: since logs['loss'] after batch index b is the mean over batches 0 through b, (b + 1) * logs['loss'] is the cumulative loss sum, and subtracting the previous cumulative sum leaves the loss of batch b alone:

from tensorflow.keras.callbacks import Callback

class CustomCallback(Callback):
    ''' This callback converts the average loss (default behavior in TF>=2.2)
        into the loss for only the current batch.
    '''
    def on_epoch_begin(self, epoch, logs=None):
        self.previous_loss_sum = 0

    def on_train_batch_end(self, batch, logs=None):
        # logs['loss'] is the running mean over batches 0..batch, so
        # (batch + 1) * logs['loss'] is the cumulative loss sum so far.
        current_loss_sum = (batch + 1) * logs['loss']
        current_loss = current_loss_sum - self.previous_loss_sum
        self.previous_loss_sum = current_loss_sum

        # use current_loss:
        # ...

This code can be added to any custom callback that needs the loss of the current batch rather than the running average.
Also, if you are using TensorFlow 1 or a TensorFlow 2 version <= 2.1, do not include this code in your callback, because those versions already report the current batch's loss rather than the average.
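
For completeness, here is a minimal sketch (not part of the original answer) of wiring this arithmetic into model.fit(). BatchLossHistory is a hypothetical name combining the callback above with a list that stores each recovered batch loss:

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class BatchLossHistory(Callback):
    def on_train_begin(self, logs=None):
        self.batch_losses = []

    def on_epoch_begin(self, epoch, logs=None):
        self.previous_loss_sum = 0.0

    def on_train_batch_end(self, batch, logs=None):
        # Same recovery trick as above: difference of cumulative loss sums.
        current_loss_sum = (batch + 1) * logs['loss']
        self.batch_losses.append(current_loss_sum - self.previous_loss_sum)
        self.previous_loss_sum = current_loss_sum

# Toy data and model just to demonstrate the callback.
x = np.random.rand(100, 4).astype('float32')
y = np.random.randint(0, 3, size=(100,))

model = tf.keras.Sequential([tf.keras.layers.Dense(3, activation='softmax')])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

history = BatchLossHistory()
model.fit(x, y, batch_size=10, epochs=2, callbacks=[history], verbose=0)
print(history.batch_losses)  # one recovered loss value per batch

Note that the recovered values can differ from the true batch losses by small floating-point rounding errors, since they are reconstructed from the running mean.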
