我尝试使用Tensorflow的KMNIST数据集和我正在使用的教科书中的一些示例代码构建一个简单的自动编码器,但在尝试拟合模型时总是出现错误。
错误显示为ValueError: Layer sequential_20 expects 1 inputs, but it received 2 input tensors.
我是TensorFlow的新手,我对这个错误的所有研究都让我感到困惑,因为它似乎涉及到我代码中没有的东西。This thread没有帮助,因为我只使用顺序层。
完整代码:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import pandas as pd
import matplotlib.pyplot as plt
#data = tfds.load(name = 'kmnist')
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
name = 'kmnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
img_train = img_train.squeeze()
img_test = img_test.squeeze()
## From Hands on Machine Learning Textbook, chapter 17
stacked_encoder = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(100, activation="selu"),
keras.layers.Dense(30, activation="selu"),
])
stacked_decoder = keras.models.Sequential([
keras.layers.Dense(100, activation="selu", input_shape=[30]),
keras.layers.Dense(28 * 28, activation="sigmoid"),
keras.layers.Reshape([28, 28])
])
stacked_ae = keras.models.Sequential([stacked_encoder, stacked_decoder])
stacked_ae.compile(loss="binary_crossentropy",
optimizer=keras.optimizers.SGD(lr=1.5))
history = stacked_ae.fit(img_train, img_train, epochs=10,
validation_data=[img_test, img_test])
6条答案
按热度按时间62lalag41#
当我改变的时候它帮助了我:
validation_data=[X_val, y_val]
到validation_data=(X_val, y_val)
你还在想为什么?
svmlkihl2#
使用
validation_data=(img_test, img_test)
代替validation_data=[img_test, img_test]
以下是编码器和解码器组合在一起的示例:
u0sqgete3#
正如Keras API参考(link)中所述,
验证数据:...验证数据可以是:- Numpy数组或Tensor的
tuple
(x_瓦尔,y_val)-Numpy数组的tuple
(x_val,y_val,val_sample_weights)-数据集...因此,validation_data必须是元组而不是列表(Numpy数组或Tensor),我们应该使用圆括号
(...)
,而不是方括号[...]
。但是,根据我有限的经验,TensorFlow 2.0.0对方括号的使用无所谓,但TensorFlow 2.3.0会抱怨。如果您的脚本在TF 2.0下运行,而不是在TF 2.3下运行,那么它就可以正常运行。
yyyllmsg4#
您已经两次给出数据而不是标签:
代替
2fjabf4q5#
在解决方案中,有些人说你需要把手镯改成圆括号,但在Colab中不起作用。是的,把
validation_data=[X_val, y_val]
改成validation_data=(X_val, y_val)
应该可以,因为它是所需的格式,但在tf==2.5.0(在Google Colab中)中,它没有解决问题。我把函数API改成了序列API,这解决了问题。奇怪。whhtz7ly6#
这个错误也可能是由于向
model.fit()
提交了错误的对象而触发的。当我想执行
与