下面是我编写的自定义回调函数,但它不起作用:
class bestval(tf.keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.history={'loss': [],'acc': [],'val_loss': [],'val_acc': []}
def on_epoch_end(self, epoch, logs={}):
#appending val_acc in history
if logs.get('val_acc', -1) != -1:
self.history['val_acc'].append(logs.get('val_acc'))
# Trying to compare current epoch val_acc with all the values in self.history['val_acc']
if logs.get('val_acc')> [i for i in self.history['val_acc']]:
filepath="model_save/weights-{epoch:02d}-{val_acc:.4f}.hdf5"
# Saving the model using TF built-in callback
checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(filepath=filepath,
monitor='val_acc', verbose=1, mode='auto')
bestobj= bestval()
拟合模型:
model.fit(xtr,ytr, epochs=4, validation_data=(xte,yte), batch_size=128, callbacks=[bestobj])
当我运行以上我得到下面的错误:
ValueError:具有多个元素的数组的真值不明确。请使用.any()或.all()
我知道我在做一些愚蠢的事情,但我不知道如何解决。任何帮助将不胜感激。
2条答案
按热度按时间6pp0gazn1#
我猜错误是在下面的行中,您正在尝试将值与列表进行比较。
if logs.get('val_acc')> [i for i in self.history['val_acc']]:
尝试,
for i in self.history['val_acc']: if logs.get('val_acc')>i: #your code
ajsxfq5m2#
代替
用途