- 此问题在此处已有答案**:
Python: pickling nested functions(3个答案)
2天前关闭。
当我尝试pickle一个在类方法中生成的Keras对象时,我得到了一个错误。
以下是类中的相关代码:
class DeepLearning:
epochs = 0
batch_size = 0
model_layers = []
def get_model(self, input_dim: int):
def build_model(input_dim=0, model_layers=[]):
model = keras.Sequential()
for model_layer in model_layers:
if model_layer.type == 0:
model.add(layers.Dense(model_layer.units))
elif model_layer.type == 1:
model.add(layers.LSTM(model_layer.units))
return model
model = KerasClassifier(build_fn=build_model,
input_dim=input_dim,
model_layers=self.model_layers)
return model
下面是示例化类并获取Keras对象的方法:
dl = DeepLearning()
# ... set DeepLearning attributes
model = dl.get_model(10)
joblib.dump(model, 'some/path/to/file') # this fails
默认值培训:无法pickle〈函数深度学习. get_model .. build_model位于0x7f6ea67c2ee0〉:它未作为app. service. cm. core. models.深度学习.深度学习. get_model .. build_model找到
此错误是什么?如何修复?
1条答案
按热度按时间wyyhbhjk1#
在我的M1 MacBook上,我可以运行你的代码。我已经粘贴了我运行的代码,做了一些修改,看看它是否与你的匹配。
我所做的更改:
1.添加了import语句,因为它不在您的代码中,而我需要它们来执行
1.将
layers.Dense
替换为tf.keras.layers.Dense
,应该没有太大区别。