有没有办法把非tf函数作为SavedModel Signature嵌入到tf.Keras模型图中?

x8goxv8g  于 2023-02-19  发布在  其他
关注(0)|答案(1)|浏览(135)

我想将预处理函数和方法作为SavedModel签名添加到模型图中。
示例:

# suppose we have a keras model
# ...

# defining the function I want to add to the model graph
@tf.function
def process(model, img_path):
    # do some preprocessing using different libs. and modules...

    outputs = {"preds": model.predict(preprocessed_img)}
    return outputs

# saving the model with a custom signature
tf.saved_model.save(new_model, dst_path, 
                    signatures={"process": process})

或者我们可以在这里使用tf.Module。但是,问题是我不能将自定义函数嵌入到保存的模型图中。
有什么办法可以做到吗?

xwmevbvl

xwmevbvl1#

我认为您对Tensorflow中save_model方法的用途有一点误解。
根据documentation,其目的是提供一种方法来序列化模型的图形,以便之后可以使用load_model进行loaded
load_model返回的模型是tf.Module的一个类,包含了它的所有方法和属性,而你需要序列化预测管道。
老实说,我不知道有什么好的方法可以做到这一点,但是你可以做的是使用一个不同的方法来序列化你的预处理参数,例如pickle或一个不同的方法,由你使用的框架提供,并在其上编写一个类,这将做以下事情:

class MyModel:
    def __init__(self, model_path, preprocessing_path):
        self.model = load_model(model_path)
        self.preprocessing = load_preprocessing(preprocessing_path)

    def predict(self, img_path):
        return self.model.predict(self.preprocessing(img_path))

相关问题