I get this error when I try to train a TensorFlow model:
ValueError: Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 800, 800, 1), found shape=(None, 640000, 1)
This is labels.shape:
(100,)
And this is imgs.shape (after expanding dimensions):
(100, 640000, 1)
Each image is 800x800 pixels with 1 color channel, but when I print(imgs[0].shape), it gives me:
(640000, 1)
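The two shapes already point at the cause: 800 × 800 = 640000, so each entry of imgs appears to be an image stored as one flat row of pixels, and np.expand_dims(imgs, -1) only appends a channel axis to that row instead of restoring height and width. A minimal sketch on a small stand-in array (two synthetic "images", not the real data/images.npy) contrasts expand_dims with reshape. It assumes the images were flattened in NumPy's default row-major order.

```python
import numpy as np

# Stand-in for data/images.npy: 2 flattened 800x800 grayscale images.
# np.arange makes every pixel value equal to its flat index, so the layout is visible.
flat = np.arange(2 * 640000, dtype=np.float32).reshape(2, 640000)
print(800 * 800)  # 640000

# What the question's code does: append a channel axis to each flat row.
expanded = np.expand_dims(flat, -1)
print(expanded.shape)  # (2, 640000, 1): 3-D, so Input(shape=(800, 800, 1)) rejects it

# Restoring height, width and channel gives the 4-D batch the model expects.
restored = flat.reshape(-1, 800, 800, 1)
print(restored.shape)  # (2, 800, 800, 1)

# Row-major reshape: flat values 800..1599 become pixel row 1 of image 0.
print(restored[0, 1, 0, 0])  # 800.0
```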
How can I change the shape of imgs_train from (100, 640000, 1) to (100, 800, 800, 1)?
Here is the full code:
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from bidict import bidict
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix

# Map the label strings '1'..'10' to integer class ids 1..10
ENCODER = bidict({
    '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6,
    '7': 7, '8': 8, '9': 9, '10': 10
})

labels = np.load('data/labels.npy')
labels = np.array([ENCODER[x] for x in labels])
print(labels.shape)

imgs = np.load('data/images.npy')
imgs = imgs.astype("float32") / 255  # scale pixel values to [0, 1]
print(imgs.shape)
imgs = np.expand_dims(imgs, -1)  # appends a trailing channel axis
print(imgs.shape)

# Shuffle images and labels together, then split 75% train / 25% test
labels, imgs = shuffle(labels, imgs)
split = .75
labels_train = labels[:int(len(labels) * split)]
labels_test = labels[int(len(labels) * split):]
imgs_train = imgs[:int(len(imgs) * split)]
imgs_test = imgs[int(len(imgs) * split):]

batch_size = 32
epochs = 10

model = keras.Sequential([
    keras.Input(shape=(800, 800, 1)),  # expects batches shaped (None, 800, 800, 1)
    layers.Conv2D(256, kernel_size=5, activation='relu'),
    layers.MaxPooling2D(pool_size=2),
    layers.Dropout(0.3),
    layers.Conv2D(512, kernel_size=5, activation='relu'),
    layers.MaxPooling2D(pool_size=2),
    layers.Dropout(0.3),
    layers.Conv2D(1024, kernel_size=5, activation='relu'),
    layers.MaxPooling2D(pool_size=2),
    layers.Dropout(0.3),
    layers.Flatten(),
    layers.Dense(len(ENCODER)+1, activation='softmax')
])

# Stop once validation accuracy has not improved for 2 epochs
early_stopping = keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=2)
optimizer = keras.optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])
model.fit(imgs_train,
          labels_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(imgs_test, labels_test),
          callbacks=[early_stopping])
model.save("alphabet_detection.h5")
```
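To see why the model insists on a 4-D input, the spatial sizes can be traced by hand. This sketch assumes the Keras defaults for the layers in the listing: Conv2D with padding='valid' and stride 1 (each 5×5 kernel shrinks a side by 4), and MaxPooling2D(pool_size=2) with stride 2 and 'valid' padding (sides are halved, rounding down). It also checks that the labels produced by ENCODER (1 to 10) fit the len(ENCODER)+1 = 11 output units, since sparse_categorical_crossentropy expects class ids from 0 up to the number of units minus 1.

```python
# Hand trace of the Sequential model's shapes, assuming Keras defaults:
# Conv2D: padding="valid", strides=1; MaxPooling2D(pool_size=2): strides=2, padding="valid".

def conv_valid(size: int, kernel: int = 5) -> int:
    # A 'valid' convolution with stride 1 loses kernel - 1 pixels per side.
    return size - kernel + 1

def pool2(size: int) -> int:
    # 2x2 max pooling with stride 2 halves the side, dropping any leftover row/column.
    return size // 2

side = 800
channels = 1
params = 0
for filters in (256, 512, 1024):
    # Conv weights: kernel_h * kernel_w * in_channels * filters, plus one bias per filter.
    params += 5 * 5 * channels * filters + filters
    side = pool2(conv_valid(side))
    channels = filters
    print(f"after block with {filters} filters: ({side}, {side}, {channels})")

flat_features = side * side * channels
num_classes = 10 + 1  # len(ENCODER) + 1 in the question's code
params += flat_features * num_classes + num_classes  # Dense weights + biases
print("Flatten ->", flat_features)
print("trainable parameters:", params)

# ENCODER yields labels 1..10; each must satisfy 0 <= label < num_classes.
encoded_labels = list(range(1, 11))
print(all(0 <= y < num_classes for y in encoded_labels))  # True
```

Under these assumptions the trace ends at (96, 96, 1024), so Flatten produces 9,437,184 features and most of the roughly 120 million parameters sit in the final Dense layer.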
1 Answer
Try calling … before creating the dataset, and see if it helps.
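Whatever the exact call in the answer was, one common way to get the shape asked for is to reshape the flat array instead of calling np.expand_dims, before the data is split. The sketch below assumes data/images.npy stores each image flattened in NumPy's default row-major (C) order; if the images were flattened some other way, the reshape would still run but would scramble the pixels. A small random batch of 8 images stands in for the real files, and the shuffle step is left out because it does not change shapes.

```python
import numpy as np

# Stand-ins for np.load('data/images.npy') and np.load('data/labels.npy');
# the real images are assumed to be flattened row-major, shape (100, 640000).
rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(8, 640000), dtype=np.uint8)
labels = rng.integers(1, 11, size=8)

imgs = imgs.astype("float32") / 255
# Replaces np.expand_dims(imgs, -1): restore height, width and channel.
# -1 lets NumPy infer the number of images; the same call also works on an
# array that already has the extra axis, i.e. shape (N, 640000, 1).
imgs = imgs.reshape(-1, 800, 800, 1)
print(imgs.shape)  # (8, 800, 800, 1)

split = .75
n_train = int(len(imgs) * split)
imgs_train, imgs_test = imgs[:n_train], imgs[n_train:]
labels_train, labels_test = labels[:n_train], labels[n_train:]
print(imgs_train.shape, imgs_test.shape)  # (6, 800, 800, 1) (2, 800, 800, 1)
```

With the real data the same lines would give imgs_train a shape of (75, 800, 800, 1), which matches keras.Input(shape=(800, 800, 1)).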