keras MNIST使用0~5个标签创建多项模型(低准确度)

7kqas0il  于 2022-11-24  发布在  其他
关注(0)|答案(1)|浏览(172)

我应该运行一个来自MNIST数据的0,1,2,3,4,5标签的模型,并检查准确性。
这是我得到的:

> import tensorflow as tf
from tensorflow import keras

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train.shape
y_train.shape

y_train[0:10]

x_train_new, y_train_new = x_train[(y_train==0) | (y_train==1) | (y_train==2) | (y_train==3) | (y_train==4) | (y_train==5)], y_train[(y_train==0) | (y_train==1) | (y_train==2) | (y_train==3) | (y_train==4) | (y_train==5)]

x_train_new.shape
y_train_new.shape

y_train_new[0:10]

y_train_onehot = tf.one_hot(y_train_new, depth=6)
y_test_onehot = tf.one_hot(y_test, depth=6)

x_train_final = x_train_new.reshape((-1, 784))
x_train_final.shape

x_test_new, y_test_new = x_test[(y_test==0) | (y_test==1) | (y_test==2) | (y_test==3) | (y_test==4) | (y_test==5)], y_test[(y_test==0) | (y_test==1) | (y_test==2) | (y_test==3) | (y_test==4) | (y_test==5)]
x_test_new.shape
x_test_final = x_test_new.reshape((-1, 784))

x_train_final = x_train_final / 255
x_test_final = x_test_final / 255

model = keras.Sequential([keras.layers.Dense(1,activation='softmax')])
model.compile(optimizer="sgd",loss="categorical_crossentropy",metrics=["accuracy"])

model.fit(x=x_train_final,y=y_train_new,epochs=5)

但是,运行后的准确性非常低(0.1872)。当我尝试将Dense从1更改为6时,得到“ValueError:形状(None,1)和(None,6)不兼容”。那么问题是什么?有人能帮助我修复我的代码吗?:(TIA

zzwlnbp8

zzwlnbp81#

在训练模型时,您不会将one_hot编码标签传递到model.fit。此外,现在您有6个标签(0、1、2、3、4、5),您需要根据提供的数据集,在模型的最后一层提及这些标签的类别计数。
请检查以下固定代码:

model = keras.Sequential([keras.layers.Dense(6,activation='softmax')])
model.compile(optimizer="sgd",loss="categorical_crossentropy",metrics=["accuracy"])

model.fit(x=x_train_final,y=y_train_onehot,epochs=5)

输出量:

Epoch 1/5
1126/1126 [==============================] - 5s 4ms/step - loss: 0.5173 - accuracy: 0.8729
Epoch 2/5
1126/1126 [==============================] - 3s 3ms/step - loss: 0.2819 - accuracy: 0.9251
Epoch 3/5
1126/1126 [==============================] - 3s 3ms/step - loss: 0.2440 - accuracy: 0.9325
Epoch 4/5
1126/1126 [==============================] - 3s 3ms/step - loss: 0.2251 - accuracy: 0.9365
Epoch 5/5
1126/1126 [==============================] - 3s 3ms/step - loss: 0.2133 - accuracy: 0.9387
<keras.callbacks.History at 0x7f22bd00cf10>

请参阅类似的link以了解更多详细信息。

相关问题