我一直在尝试使用Tensorflow和Keras构建一个图像分类模型。该模型应该对显示的手指数量进行分类。
我尝试过用各种不同的方法训练模型,但是现在,尽管训练和验证的准确率达到了95%以上,模型每次都以相同的置信度返回相同的预测。即使我尝试在训练模型的图像上预测类,它仍然返回相同的结果。
数据样本:
最后一次训练:
train_ds = tf.keras.utils.image_dataset_from_directory(
dataset_path,
image_size=(300,300),
seed=123,
batch_size=batch_size,
color_mode='grayscale')
val_ds = tf.keras.utils.image_dataset_from_directory(
dataset_pathh,
image_size=(300,300),
seed=123,
batch_size=batch_size,
color_mode='grayscale')
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
num_classes = len(class_names)
model = Sequential([
layers.Rescaling(1./255),
layers.Conv2D(32, 3, activation='relu', input_shape=(300, 300, 1)),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(128, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(128, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(512, activation='relu'),
layers.Dense(num_classes, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
epochs=5
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
培训结果:
得到预测:
img = tf.keras.utils.load_img(
img_path, target_size=(300, 300)
)
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)
img_array = tf.image.rgb_to_grayscale(img_array)
img_array = img_array / 255
predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])
返回的预测:
tf.Tensor([0.1295672 0.1295672 0.1295672 0.1295672 0.35214293 0.12958823], shape=(6,), dtype=float32)
我是这方面的初学者,主要是遵循Tensorflow教程。
编辑:忘记提到数据集是平衡的,所以问题不存在。
1条答案
按热度按时间ckocjqey1#
在没有访问数据的情况下很难确定确切的原因,但是原始代码中有一些错误。修复这些错误可以提高模型性能。
首先,损失函数期望未标度的logits。这些是层在任何激活之前返回的值。然而,在网络架构中,在最后一层有一个softmax激活。要解决此问题,请将激活设置为
None
,以便该层返回未缩放的logits。第二,模型实现包含了图像缩放,但也要在评估代码中缩放图像。在评估代码中,这会导致图像缩放 * 两次 *。删除评估代码中的以下行:
除了实现中的这些错误之外,数据集的不平衡也可能导致性能问题。考虑添加AUROC和AUPR等指标,这有助于给予更好的性能视图。