keras进度条可以显示即时指标而不是运行平均值吗?

djmepvbi  于 2022-11-13  发布在  其他
关注(0)|答案(2)|浏览(141)

我所说的“进度条”是指显示为tf.keras.Model.fit的标准进度条
据我所知,它显示了所选指标的运行平均值(在当前时期内),但我希望它显示最后一次完成迭代时的值。
是否有一个内置的方法来进行这种改变?如果没有,最简单的方法是什么?

odopli94

odopli941#

我前一段时间打了个电话来解决这个问题。

class print_on_end(Callback):
  def on_batch_end(self, batch, logs={}):
    print()

你想这样称呼它。

model.fit(training_dataset, steps_per_epoch=num_training_samples, epochs=EPOCHS,validation_data=validation_dataset, callbacks=[print_on_end()])

但是这个回调打印的是平均值,只是在不同的行上,所以我不认为这是你想要的。
这反而:

class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))
            )
        )

这个回调打印了每一批的损失,所以它应该是你要找的。
(如果需要度量,只需将logs["loss"]更改为logs["name of the metric"],例如logs["mean_absolute_error"]
编辑:
要检查日志中的度量名称,您可以打印日志的键并找到您正在搜索的ONA。

class PrintKeys(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print(keys)
            )
        )

在这种方法中,您应该只找到损失和度量的关键字。
来源:https://keras.io/guides/writing_your_own_callbacks/

osh3o9ms

osh3o9ms2#

MeanSquaredError度量的类层次结构的示例如下

MeanSquaredError->MeanMetricWrapper->Mean->Reduce->Metric

有内置的方式吗?
主要问题是所有指标都是Reduce指标的子类,Reduce指标执行聚合,并且没有预见到要更改Reduce基类的行为。
如何实现这一点最容易
给定上面的模式,您可以通过创建MeanMetricWrapper的新指标子类来实现您想要的结果,该子类通过首先调用self.reset_state,然后调用MeanMetricWrapper.update_state来覆盖MeanMetricWrapper的update_state方法。这样,底层Reduce基类中的聚合将只聚合一个值。工作示例如下:

#! /usr/bin/env python
import numpy as np
import keras
from keras.metrics import MeanMetricWrapper

x=np.linspace(0, 1, 20000)[:,np.newaxis,np.newaxis]
y=np.sin(x*2*np.pi)

model = keras.Sequential()
model.add(keras.layers.Dense(4, activation="tanh", input_shape=(1,1)))
model.add(keras.layers.Dense(4, activation="tanh"))
model.add(keras.layers.Dense(4))

#####
# Here the Instantaneous metric variant 
class InstMetric(MeanMetricWrapper):
    def __init__(self, fn, **kwargs):
        """ fn is the callable loss function you want to use in your metric """
        super().__init__(fn=fn, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.reset_states()
        return super().update_state(y_true, y_pred, sample_weight=sample_weight)
#####

model.compile(optimizer='adam', loss='mean_squared_error',
              metrics=[
                 keras.metrics.MeanSquaredError(name="MSE"),
                 InstMetric(keras.metrics.mean_squared_error, name="IMSE")
              ]
)

model.fit(x=x, y=y, epochs=1, batch_size=5, steps_per_epoch=1000)

将此脚本存储为inst_demo.py,并通过tr运行它,以展开终端中的进度条

$> ./inst_demo.py | tr \\r \\n 

   1/1000 [..............................] - ETA: 8:07 - loss: 0.4656 - MSE: 0.4656 - IMSE: 0.4656
  42/1000 [>.............................] - ETA: 1s - loss: 0.4874 - MSE: 0.4874 - IMSE: 0.4133  
  87/1000 [=>............................] - ETA: 1s - loss: 0.4685 - MSE: 0.4685 - IMSE: 0.4764
 132/1000 [==>...........................] - ETA: 1s - loss: 0.4627 - MSE: 0.4627 - IMSE: 0.5445
 175/1000 [====>.........................] - ETA: 0s - loss: 0.4558 - MSE: 0.4558 - IMSE: 0.7689
 217/1000 [=====>........................] - ETA: 0s - loss: 0.4443 - MSE: 0.4443 - IMSE: 0.1058
 264/1000 [======>.......................] - ETA: 0s - loss: 0.4258 - MSE: 0.4258 - IMSE: 0.4162
 311/1000 [========>.....................] - ETA: 0s - loss: 0.4090 - MSE: 0.4090 - IMSE: 0.1716
 356/1000 [=========>....................] - ETA: 0s - loss: 0.3889 - MSE: 0.3889 - IMSE: 0.3417
 400/1000 [===========>..................] - ETA: 0s - loss: 0.3707 - MSE: 0.3707 - IMSE: 0.1271
 445/1000 [============>.................] - ETA: 0s - loss: 0.3532 - MSE: 0.3532 - IMSE: 0.0729
 489/1000 [=============>................] - ETA: 0s - loss: 0.3383 - MSE: 0.3383 - IMSE: 0.2310
 535/1000 [===============>..............] - ETA: 0s - loss: 0.3248 - MSE: 0.3248 - IMSE: 0.1228
 580/1000 [================>.............] - ETA: 0s - loss: 0.3143 - MSE: 0.3143 - IMSE: 0.2670
 625/1000 [=================>............] - ETA: 0s - loss: 0.3048 - MSE: 0.3048 - IMSE: 0.1762
 671/1000 [===================>..........] - ETA: 0s - loss: 0.2962 - MSE: 0.2962 - IMSE: 0.0751
 715/1000 [====================>.........] - ETA: 0s - loss: 0.2896 - MSE: 0.2896 - IMSE: 0.0650
 756/1000 [=====================>........] - ETA: 0s - loss: 0.2831 - MSE: 0.2831 - IMSE: 0.2332
 799/1000 [======================>.......] - ETA: 0s - loss: 0.2773 - MSE: 0.2773 - IMSE: 0.1026
 841/1000 [========================>.....] - ETA: 0s - loss: 0.2721 - MSE: 0.2721 - IMSE: 0.1238
 888/1000 [=========================>....] - ETA: 0s - loss: 0.2673 - MSE: 0.2673 - IMSE: 0.1471
 936/1000 [===========================>..] - ETA: 0s - loss: 0.2631 - MSE: 0.2631 - IMSE: 0.2242
 986/1000 [============================>.] - ETA: 0s - loss: 0.2580 - MSE: 0.2580 - IMSE: 0.2704
1000/1000 [==============================] - 2s 1ms/step - loss: 0.2574 - MSE: 0.2574 - IMSE: 0.2773

因此,每次更新进度条时,您都会获得一个即时值。
如果您不想选择要使用的指标,也可以从任何可用的keras指标中导出InstMetric。

相关问题