我想实现一个基于平衡的准确性评分的模型检查点回调。为此,我实现了以下类:
class BalAccScore(keras.callbacks.Callback):
def __init__(self, validation_data=None):
super(BalAccScore, self).__init__()
self.validation_data = validation_data
def on_train_begin(self, logs={}):
self.balanced_accuracy = []
def on_epoch_end(self, epoch, logs={}):
y_predict = tf.argmax(self.model.predict(self.validation_data[0]), axis=1)
y_true = tf.argmax(self.validation_data[1], axis=1)
balacc = balanced_accuracy_score(y_true, y_predict)
self.balanced_accuracy.append(round(balacc,6))
logs["val_bal_acc"] = balacc
keys = list(logs.keys())
print("\n ------ validation balanced accuracy score: %f ------\n" %balacc)
然后定义以下回调
balAccScore = BalAccScore(validation_data=(X_2, y_2))
mc = ModelCheckpoint(filepath=callback_path, monitor="val_bal_acc", verbose=1, save_best_only=True, save_freq='epoch')
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['val_bal_acc'])
history = model.fit(X_1, y_1, epochs = 5, batch_size = 512,
callbacks=[balAccScore, mc],
validation_data = (X_2, y_2)
)
然后得到错误ValueError: Unknown metric function: val_bal_acc
尽管事实上我在使用accuracy时发现它在history下,例如,通过在编译时设置metrics=[“acc”]。在这种情况下,我会得到预期的警告:
WARNING:tensorflow:Can save best model only with val_bal_acc available, skipping.
但除此之外,模型运行完美。不知道为什么它不运行,否则。
2条答案
按热度按时间7y4bm7vi1#
你应该在compile中去掉引号:
或者至少在R中是这样的
bnl4lu3b2#
您收到该错误是因为在编译模型时没有将
balanced_accuracy_score
作为metrics参数的值传递。在编译模型时传递到metrics参数中的字符串“val_bal_acc”不起作用,因为它不是已知的指标。您可以通过指标的字符串名称访问指标。仅适用于那些已经在tf.keras.metrics
中实现的指标。如果要在培训期间监视验证平衡的准确性,则应实现自定义指标类(你可以在这里看到如何做),然后将它传递给metrics
参数。一旦你完成了这个操作,你就可以使用你给它的前缀为'val'的名称来监视你的自定义指标,如果您想在验证期间监控它。不需要像您那样进行补充的自定义回调,一旦您定义了指标,日志就会自动更新。对于这种特殊情况,您可以在此问题的答案中找到此度量的一些实现。如果你更喜欢回调方法,你不需要定义一个自定义指标,而是利用已经记录的指标,你可以在我的答案here中找到它的实现。