我觉得我错过了什么小东西,但似乎不能弄清楚它是什么。我试图在cassava数据集上训练一个非常简单的模型,但当我调用fit
函数时,输入名称与预期名称不匹配。我尝试命名输入层以匹配模型,但tf坚持将_input附加到层名,导致冲突。我敢肯定这是tfds的一个相当典型的用例,而且一定是一些微不足道的东西。
错误:
ValueError: Missing data for input "flatten_input". You passed a data dictionary with keys ['image', 'image/filename', 'label']. Expected the following keys: ['flatten_input']
我从一个github项目中借用了查看代码,这个代码确实可以工作,因为我可以查看加载的数据。
# tensorflow 2.x core api
import logging
from mlflow.models import infer_signature
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras as K
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
#############################################################################################################
from matplotlib import pyplot as plt
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(int(x), int(y), i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
# dataset, info = tfds.load('cassava', with_info=True)
dataset, info = tfds.load("cassava", shuffle_files=True, with_info=True)
print("INFO:\n", info)
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
def create_model(type="default", n_classes=6):
if type == "something":
pass
else:
model = K.Sequential()
# naming the below layer with name argument still appends the _input to the actually name
model.add(K.layers.Flatten(input_shape=(244, 244, 3)))
model.add(K.layers.Dense(512, activation="relu"))
model.add(K.layers.Dense(256, activation="relu"))
model.add(K.layers.Dense(128, activation="relu"))
model.add(K.layers.Dense(64, activation="relu"))
model.add(K.layers.Dense(n_classes, activation="softmax"))
model.compile(loss='sparse_categorical_crossentropy', optimizer=K.optimizers.Adam(0.01), metrics=['accuracy'])
return model
# batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
# examples = next(batch)
# plot(examples)
print(tf.__version__)
model = create_model()
model.fit(dataset["train"], epochs=5)
1条答案
按热度按时间mnemlml81#
也许你可以考虑使用tensorflow_datasets来查看这个tf示例:https://www.tensorflow.org/guide/keras/transfer_learning。
相应地,您可以按如下方式重写数据加载器代码
定义一个新的数据预处理函数,使其同时适应映像创建和批处理创建
然后,您应该能够训练您的模型
希望这对你有帮助!