keras Tensorflow的周期组模型:拟合误差法

5lwkijsr  于 2023-01-05  发布在  其他
关注(0)|答案(1)|浏览(150)

我试着从这里运行代码:https://keras.io/examples/generative/cyclegan/,但在运行model.fit(..)时,我收到以下错误:

ValueError: Model <__main__.CycleGan object at 0x7fec3de767c0> cannot be saved either because the
input shape is not available or because the forward pass of the model is not defined.To define a 
forward pass, please override `Model.call()`. To specify an input shape, either call 
`build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or 
`Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass
in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

即使我通过作者链接的colab文件运行它。是否有其他人已经遇到过这个问题并知道如何解决它?我也尝试通过model.build((batch_size,256,256,3))预定义inputsize,但仍然得到相同的错误。
如果我对回调进行注解,它就可以工作。我认为问题出在model_checkpoint_callback中。没有它,代码可以工作,但是我不能保存模型。
非常感谢提前为每一个答案!

fnatzsnv

fnatzsnv1#

解决方案是添加以下行:

# Create cycle gan model
cycle_gan_model = CycleGan(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)
# add following line:
cycle_gan_model.compute_output_shape(input_shape=(None, 256, 256, 3))

相关问题