keras 无法pickle具有嵌套函数的类方法[重复]

vlju58qv  于 2023-02-08  发布在  其他
关注(0)|答案(1)|浏览(115)
    • 此问题在此处已有答案**:

Python: pickling nested functions(3个答案)
2天前关闭。
当我尝试pickle一个在类方法中生成的Keras对象时,我得到了一个错误。
以下是类中的相关代码:

class DeepLearning:

    epochs = 0
    batch_size = 0
    model_layers = []

    def get_model(self, input_dim: int):

        def build_model(input_dim=0, model_layers=[]):

            model = keras.Sequential()
        
            for model_layer in model_layers:

                if model_layer.type == 0:
                    model.add(layers.Dense(model_layer.units))
                elif model_layer.type == 1:
                    model.add(layers.LSTM(model_layer.units))

            return model

       model = KerasClassifier(build_fn=build_model, 
                                input_dim=input_dim, 
                                model_layers=self.model_layers)

      return model

下面是示例化类并获取Keras对象的方法:

dl = DeepLearning()
# ... set DeepLearning attributes
model = dl.get_model(10)
joblib.dump(model, 'some/path/to/file')  # this fails

默认值培训:无法pickle〈函数深度学习. get_model .. build_model位于0x7f6ea67c2ee0〉:它未作为app. service. cm. core. models.深度学习.深度学习. get_model .. build_model找到
此错误是什么?如何修复?

wyyhbhjk

wyyhbhjk1#

from tensorflow import keras
import tensorflow as tf
import joblib
from scikeras.wrappers import KerasClassifier

class DeepLearning:

    epochs = 0
    batch_size = 0
    model_layers = []

    def get_model(self, input_dim: int):
        def build_model(input_dim=0, model_layers=[]):

            model = keras.Sequential()

            for model_layer in model_layers:

                if model_layer.type == 0:
                    model.add(tf.keras.layers.Dense(model_layer.units))
                elif model_layer.type == 1:
                    model.add(tf.keras.layers.LSTM(model_layer.units))

            return model

        model = KerasClassifier(
            build_fn=build_model, input_dim=input_dim, model_layers=self.model_layers
        )

        return model

dl = DeepLearning()
# ... set DeepLearning attributes
model = dl.get_model(10)
joblib.dump(model, "model")

在我的M1 MacBook上,我可以运行你的代码。我已经粘贴了我运行的代码,做了一些修改,看看它是否与你的匹配。
我所做的更改:
1.添加了import语句,因为它不在您的代码中,而我需要它们来执行
1.将layers.Dense替换为tf.keras.layers.Dense,应该没有太大区别。

相关问题