keras 如果验证准确性比上一个时期有所提高,如何编写自定义回调以在每个时期保存模型

wztqucjr  于 2022-11-24  发布在  其他
关注(0)|答案(2)|浏览(102)

下面是我编写的自定义回调函数,但它不起作用:

class bestval(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.history={'loss': [],'acc': [],'val_loss': [],'val_acc': []}

    def on_epoch_end(self, epoch, logs={}):
        #appending val_acc in history
        if logs.get('val_acc', -1) != -1:
            self.history['val_acc'].append(logs.get('val_acc'))
        # Trying to compare current epoch val_acc with all the values in self.history['val_acc']
        if logs.get('val_acc')> [i for i in self.history['val_acc']]:
            filepath="model_save/weights-{epoch:02d}-{val_acc:.4f}.hdf5"
            # Saving the model using TF built-in callback 
            checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(filepath=filepath, 
            monitor='val_acc',  verbose=1, mode='auto')
bestobj= bestval()

拟合模型:

model.fit(xtr,ytr, epochs=4, validation_data=(xte,yte), batch_size=128, callbacks=[bestobj])

当我运行以上我得到下面的错误:
ValueError:具有多个元素的数组的真值不明确。请使用.any()或.all()
我知道我在做一些愚蠢的事情,但我不知道如何解决。任何帮助将不胜感激。

6pp0gazn

6pp0gazn1#

我猜错误是在下面的行中,您正在尝试将值与列表进行比较。if logs.get('val_acc')> [i for i in self.history['val_acc']]:
尝试,for i in self.history['val_acc']: if logs.get('val_acc')>i: #your code

ajsxfq5m

ajsxfq5m2#

代替

if logs.get('val_acc') > [i for i in self.history['val_acc']]

用途

if any(logs.get('val_acc')> val for val in self.history['val_acc'])

相关问题