如何保存和重新加载Tensorflow/Keras和Keras-cv YOLO模型

yrdbyhpb  于 2024-01-08  发布在  其他
关注(0)|答案(1)|浏览(133)

我一直在通过Keras网站的例子:https://keras.io/examples/vision/yolov8/在Tensorflow/Keras中构建YOLOv 8模型。我已经成功训练了一个模型,尽管我在使用回调时遇到了错误,所以我删除了回调,并尝试在模型训练完成后手动保存模型。

yolo.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3
)
yolo.save('my_yolo_mdl.keras')

yolo_load=tf.keras.models.load_model('my_yolo_mdl.keras')

字符串
我在这里得到一个警告,说模型没有被编译,所以我用和训练之前一样的方式编译它:

optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=LEARNING_RATE,
    global_clipnorm=GLOBAL_CLIPNORM,
)

yolo_load.compile(
    optimizer=optimizer, classification_loss="binary_crossentropy", box_loss="ciou"
)


然后,当我尝试使用编译后的模型和形状为(1,:,:,3)的图像进行预测时,我得到了一个错误:

yolo_load.predict(img)

TypeError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2440, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2425, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2413, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py", line 616, in predict_step
        return self.decode_predictions(outputs, args[-1])
    File "/usr/local/lib/python3.10/dist-packages/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py", line 609, in decode_predictions
        return self.prediction_decoder(box_preds, scores)

    TypeError: '_DictWrapper' object is not callable


如果我使用原始模型,而不是加载的模型,此输入数组将按预期工作。我哪里出错了?谢谢

qcbq4gxm

qcbq4gxm1#

编辑:不工作。
在load_model函数中,将compile参数设置为false

yolo_load=tf.keras.models.load_model('my_yolo_mdl.keras', compile=False)

字符串
然后按照第二个代码块中所示的方式再次编译它。

相关问题