如何在MLFlow中记录已经训练好的Tensorflow模型?

y3bcpkx1  于 2023-10-23  发布在  其他
关注(0)|答案(1)|浏览(133)

我正在尝试记录一个Tensorflow模型,该模型已经加载了model = tf.saved_model.load(my_model_directory)。当我调用mlflow.tensorflow.log_model(model, "model")时,我收到一个错误
在MLFlow FAQ中,他们有一个用Keras模型做同样事情的例子。

import mlflow
import tensorflow as tf

model = tf.keras.models.load_model("tf_keras_model")
with mlflow.start_run() as run:
    mlflow.keras.log_model(model, "model")

我用Tensorflow模型试过:

imported: AutoTrackable = tf.saved_model.load(my_model_directory)
inference_function: WrappedFunction = imported.signatures["serving_default"]

mlflow.set_tracking_uri("http://127.0.0.1:5000")
with mlflow.start_run() as run:
    #mlflow.tensorflow.log_model(imported, "model")
    mlflow.tensorflow.log_model(inference_function, "model")

运行此命令将给出错误
mlflow.exceptions.MlflowException:未知的模型类型:<class 'tensorflow.python.eager.wrap_function.WrappedFunction'>

mlflow.exceptions.MlflowException:未知模型类型:<class 'tensorflow.python.trackable.autotrackable. AutoTrackable'>
.log_model()函数的MLFlow文档说明第一个参数应该是“TF2核心模型(继承tf.Module)或MLflow模型格式的Keras模型”。如何从AutoTrackable中获取可以从tf.saved_model.load()获取的日志?

tgabmvqs

tgabmvqs1#

您看到的错误是因为tf.saved_model.load()返回的对象不是直接的TensorFlow 2核心模型(继承tf.Module)或Keras模型,这是mlflow.tensorflow.log_model()所期望的。相反,tf.saved_model.load()返回AutoTrackable对象,这是序列化TensorFlow对象的更通用表示。
要使用MLflow记录模型,可以执行以下步骤:
1.将SavedModel转换为ConcreteFunction:这将允许您从加载的模型中获得可调用的TensorFlow函数,可以用于进行预测。
1.创建自定义tf.Module:TensorFlow 2.x鼓励使用tf.Module封装模型。由于MLflow需要tf.Module或Keras模型,因此您可以创建自定义模块并将ConcreteFunction添加到其中。
1.使用MLflow记录自定义tf.Module:现在您有了一个tf.Module,您可以使用MLflow记录它。
以下是如何实现这些步骤:

import tensorflow as tf
import mlflow

# Load the SavedModel
imported = tf.saved_model.load(my_model_directory)

# Convert the SavedModel to a ConcreteFunction
inference_function = imported.signatures["serving_default"]

# Create a custom tf.Module
class CustomModule(tf.Module):
    def __init__(self, inference_function):
        self.inference_function = inference_function

    @tf.function(input_signature=inference_function.input_signature)
    def call(self, *args, **kwargs):
        return self.inference_function(*args, **kwargs)

# Create an instance of the custom module with the inference function
module_instance = CustomModule(inference_function)

# Log the model with MLflow
mlflow.set_tracking_uri("http://127.0.0.1:5000")
with mlflow.start_run() as run:
    mlflow.tensorflow.log_model(module_instance, "model")

这种方法将ConcreteFunction Package 在自定义tf.Module中,并且应该满足使用MLflow进行日志记录的要求。
参考文献:
TensorFlow 2 SavedModel documentation
MLflow TensorFlow documentation

相关问题