从 flask 中的.h5文件加载keras(tensorflow)模型时出错

bqf10yzr  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(173)

当我运行文件时,我得到这些错误:

python app.py
2023-09-03 22:09:58.412966: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE SSE2 SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
  File "C:\Users\Osama\python enviroment\app.py", line 95, in <module>
    model = load_local_model()
            ^^^^^^^^^^^^^^^^^^
  File "C:\Users\Osama\python enviroment\app.py", line 29, in load_local_model
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Osama\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\saving\saving_api.py", line 238, in load_model
    return legacy_sm_saving_lib.load_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Osama\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\Osama\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\utils\generic_utils.py", line 102, in func_load
    code = marshal.loads(raw_code)
           ^^^^^^^^^^^^^^^^^^^^^^^
ValueError: bad marshal data (unknown type code)

下面是完整的代码:

import os
import uuid
import requests
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
import json

load_dotenv()

app = Flask(__name__)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
model = None

ALLOWED_EXT = {'jpg', 'jpeg', 'png', 'jfif'}
classes = ['Meningioma', 'Glioma', 'Pituitary']

def allowed_file(filename: str) -> bool:
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXT

def load_local_model() -> tf.keras.Model:
    model_path = os.path.join(BASE_DIR, 'model.h5')
    model = load_model(model_path)
    return model

def predict(filename: str, model) -> tuple[list[str], list[float]]:
    img = load_img(filename, target_size=(256, 256))
    img = img_to_array(img)
    img = img.reshape(1, 256, 256, 3)

    img = img.astype('float32')
    img = img / 255.0
    result = model.predict(img)

    dict_result = {}
    for i in range(len(classes)):
        dict_result[result[0][i]] = classes[i]

    res = result[0]
    res.sort()
    res = res[::-1]
    prob = res[:3]

    prob_result = []
    class_result = []
    for i in range(len(prob)):
        prob_result.append(round(prob[i] * 100, 3))
        class_result.append(dict_result[prob[i]])

    return class_result, prob_result

@app.route('/predict', methods=['POST'])
def predict_image():
    if 'file' not in request.files:
        return jsonify({'error': 'No file found in the request'})

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'})

    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        unique_filename = str(uuid.uuid4())
        img_path = os.path.join(
            BASE_DIR, 'static/images', unique_filename + '.jpg')
        file.save(img_path)

        class_result, prob_result = predict(img_path, model)

        predictions = {
            "class1": class_result[0],
            "class2": class_result[1],
            "class3": class_result[1],
            "prob1": prob_result[0],
            "prob2": prob_result[1],
            "prob3": prob_result[2],
        }

        return jsonify(predictions)
        return jsonify({'message': 'serve is working'})
    else:
        return jsonify({'error': 'Invalid file format'})

if __name__ == "__main__":
    # Load model locally
    model = load_local_model()
    print("Model loaded.")

    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 5000)), debug=False)

这里是GitHub repo,如果你想看看文件结构。
我真的没有使用Python的经验,因为我是一个Web前端开发人员,所以可能代码有问题,但我已经尝试了一切:ChatGpt,Bing,当然还有在Stackoverflow上花了几个小时。
谢谢你的时间,谢谢。

8yoxcaq7

8yoxcaq71#

这个问题是由于在我的Flask应用程序中使用不兼容的版本造成的。我是这样解决的:

  • TensorFlow版本不匹配:
  • 该模型使用TensorFlow v2.12.0进行训练,但我在Flask应用程序中安装了TensorFlow v2.13.0。
  • 为了解决这个问题,我降级了我的TensorFlow版本,以匹配训练期间使用的版本。您可以通过运行以下命令来执行此操作:
pip install tensorflow==2.12.0
  • Python版本不匹配:
  • 该模型是使用Python v3.10.9构建的,而我在Flask应用程序中使用的是Python v3.11.3。
  • 为了确保兼容性,我通过创建虚拟环境并安装Python v3.10.9切换到正确的Python版本。步骤如下:
  • 安装virtualenv如果你还没有它:
pip install virtualenv
  • 创建虚拟环境:
virtualenv myenv
  • 激活虚拟环境:
  • 对于Windows:
myenv\Scripts\activate
  • 对于Linux/macOS:
source myenv/bin/activate
  • 在虚拟环境中安装Python v3.10.9:
pip install python==3.10.9

在应用这些更改后,我的Flask应用程序中的TensorFlow版本和Python版本与模型兼容。这解决了我的问题,我希望它也能帮助你。
请注意,确保模型、TensorFlow和Python之间的版本兼容性非常重要,以避免将来出现此类问题。

相关问题