Keras: high training and validation accuracy, poor test accuracy

Posted by llycmphe on 2022-11-24 in Other

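I am trying to classify images into two classes. After 10 epochs I get high training and validation accuracy (0.97), but my test results are terrible (accuracy 0.48), and the confusion matrix shows the network predicting the wrong class for the images (results attached).
The dataset has only 2 classes, with 10,000 example images per class (after augmentation). I am using the VGG16 network. 20% of the whole dataset is split off as the test set (the split takes images at random from each class, so it is shuffled). The remaining images are split into an 80% training set and a 20% validation set (as shown in the ImageDataGenerator line of the code). So in the end there are:
12,904 training images belonging to 2 classes
3,224 validation images belonging to 2 classes
4,032 test images belonging to 2 classes
Here is my code: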

def CNN(CNN='VGG16', choice='predict', prediction='./dataset/Test/image.jpg'):
    ''' Train images using one of several CNNs '''
    Train   = './dataset/Train'
    Tests   = './dataset/Test'
    shape   = (224, 224)
    epochs  = 10
    batches = 16
    classes = []
    for c in os.listdir(Train): classes.append(c)
    IDG = keras.preprocessing.image.ImageDataGenerator(validation_split=0.2)
    train = IDG.flow_from_directory(Train, target_size=shape, color_mode='rgb',
        classes=classes, batch_size=batches, shuffle=True, subset='training')
    valid = IDG.flow_from_directory(Train, target_size=shape, color_mode='rgb',
        classes=classes, batch_size=batches, shuffle=True, subset='validation')
    tests = IDG.flow_from_directory(Tests, target_size=shape, color_mode='rgb',
        classes=classes, batch_size=batches, shuffle=True)
    input_shape = train.image_shape
    if CNN == 'VGG16' or 'vgg16':
        model = VGG16(weights=None, input_shape=input_shape,
            classes=len(classes))
    elif CNN == 'VGG19' or 'vgg19':
        model = VGG19(weights=None, input_shape=input_shape,
            classes=len(classes))
    elif CNN == 'ResNet50' or 'resnet50':
        model = ResNet50(weights=None, input_shape=input_shape,
            classes=len(classes))
    elif CNN == 'DenseNet201' or 'densenet201':
        model = DenseNet201(weights=None, input_shape=input_shape,
            classes=len(classes))
    model.compile(optimizer=keras.optimizers.SGD(
            lr=1e-3,
            decay=1e-6,
            momentum=0.9,
            nesterov=True),
            loss='categorical_crossentropy',
            metrics=['accuracy'])
    Esteps = int(train.samples/train.next()[0].shape[0])
    Vsteps = int(valid.samples/valid.next()[0].shape[0])
    if choice == 'train':
        history= model.fit_generator(train,
            steps_per_epoch=Esteps,
            epochs=epochs,
            validation_data=valid,
            validation_steps=Vsteps,
            verbose=1)
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('Model Loss')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
        plt.show()
        plt.plot(history.history['acc'])
        plt.plot(history.history['val_acc'])
        plt.title('Model Accuracy')
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
        plt.show()
        Y_pred = model.predict_generator(tests, verbose=1)
        y_pred = np.argmax(Y_pred, axis=1)
        matrix = confusion_matrix(tests.classes, y_pred)
        df_cm  = pd.DataFrame(matrix, index=classes, columns=classes)
        plt.figure(figsize=(10,7))
        sn.heatmap(df_cm, annot=True)
        print(classification_report(tests.classes,y_pred,target_names=classes))
        model.save_weights('weights.h5')
    elif choice == 'predict':
        model.load_weights('./weights.h5')
        img = image.load_img(prediction, target_size=shape)
        im = image.img_to_array(img)
        im = np.expand_dims(im, axis=0)
        if CNN == 'VGG16' or 'vgg16':
            im = keras.applications.vgg16.preprocess_input(im)
            prediction = model.predict(im)
            print(prediction)
        elif CNN == 'VGG19' or 'vgg19':
            im = keras.applications.vgg19.preprocess_input(im)
            prediction = model.predict(im)
            print(prediction)
        elif CNN == 'ResNet50' or 'resnet50':
            im = keras.applications.resnet50.preprocess_input(im)
            prediction = model.predict(im)
            print(prediction)
            print(keras.applications.resnet50.decode_predictions(prediction))
        elif CNN == 'DenseNet201' or 'densenet201':
            im = keras.applications.densenet201.preprocess_input(im)
            prediction = model.predict(im)
            print(prediction)
            print(keras.applications.densenet201.decode_predictions(prediction))

CNN(CNN='VGG16', choice='train')

Results:
First run

              precision    recall  f1-score   support
    Predator       0.49      0.49      0.49      2016
    Omnivore       0.49      0.49      0.49      2016
    accuracy         --        --      0.49      4032

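I suspect that ImageDataGenerator() does not shuffle the images before the train/validation split. If that is the case, how can I force Keras's ImageDataGenerator to shuffle the dataset before splitting?
If shuffling is not the issue, how can I fix my problem? What am I doing wrong?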

fykwrbwg #1

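So your model is basically overfitting, which means it is "memorizing" your training set. I have a few suggestions:
1. Check whether the 2 predicted classes are balanced in the training set, for example half labeled 0 and half labeled 1. If, say, 90% of the training data is labeled 0, the model will simply predict 0 for everything and be correct 90% of the time in validation.
2. If your training data is already balanced, that means your model is not generalizing. Perhaps you could try a pretrained model instead of training every layer of VGG yourself? You can load VGG's pretrained weights without the top layers and train only the dense layers.
3. Use cross-validation. Reshuffle the data for each validation fold and see whether the results on the test set improve.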
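
To make the first suggestion concrete, here is a minimal sketch of a class-balance check, using only the standard library. The helper names `class_balance` and `majority_baseline_accuracy` are made up for illustration, and the label lists are hypothetical; it only shows how much accuracy a model could reach by always predicting the most common class.

```python
from collections import Counter


def class_balance(labels):
    """Return {label: fraction of samples} for a sequence of class labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}


def majority_baseline_accuracy(labels):
    """Accuracy of a model that always predicts the most common class."""
    return max(class_balance(labels).values())


if __name__ == "__main__":
    # Hypothetical imbalanced label list: 90% class 0, 10% class 1.
    imbalanced = [0] * 900 + [1] * 100
    print(class_balance(imbalanced))
    print(majority_baseline_accuracy(imbalanced))

    # Hypothetical balanced label list with the same sizes as the test set
    # in the question (2016 images per class).
    balanced = [0] * 2016 + [1] * 2016
    print(majority_baseline_accuracy(balanced))
```

With the generators from the question, the label sequence would presumably come from `train.classes`, the same attribute the question already reads as `tests.classes`. If the baseline comes out close to 0.5, class imbalance is unlikely to be what is inflating the accuracy.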

nuypyhwy #2

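Somehow, Keras's image generator works well when combined with fit() or fit_generator(), but fails badly when combined with predict_generator() or predict().
When using the PlaidML Keras backend on an AMD processor, I prefer to loop over all the test images one by one and get the prediction for each image in each iteration.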

import os
from PIL import Image
import keras
import numpy

# Code for creating and training the model is not included;
# `model` is assumed to be an already trained Keras model.

print("Prediction result:")
test_dir = "/path/to/test/images"
files = os.listdir(test_dir)
correct = 0
total = 0
# Dictionary mapping each class index to its animal category label.
classes = {
    0: 'This is Cat',
    1: 'This is Dog',
}
for file_name in files:
    total += 1
    image = Image.open(os.path.join(test_dir, file_name)).convert('RGB')
    image = image.resize((100, 100))
    # Scale pixels to [0, 1] and add a batch axis: shape (1, 100, 100, 3).
    image = numpy.expand_dims(numpy.array(image) / 255, axis=0)
    # argmax over the class scores (assumes one output score per class);
    # predict_classes() has been removed from recent Keras releases.
    pred = int(numpy.argmax(model.predict(image), axis=1)[0])
    animals_category = classes[pred]
    # The true label is read from the file name; compare case-insensitively.
    name = file_name.lower()
    label = animals_category.lower()
    if ("cat" in name and "cat" in label) or ("dog" in name and "dog" in label):
        print(correct, ". ", file_name, animals_category)
        correct += 1
print("accuracy: ", (correct / total))
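
A plausible reason why generator-based prediction looks broken, as in the question, is that the test generator there is created with `shuffle=True` while the predictions are compared against `tests.classes`, which stays in directory order. The sketch below is a pure-Python simulation of that mismatch, not Keras code: `simulate` and `accuracy` are hypothetical helpers, and the classifier is assumed to be perfect so that any drop in accuracy comes from the ordering alone.

```python
import random


def accuracy(y_true, y_pred):
    """Fraction of positions where the two label sequences agree."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def simulate(n_per_class=2016, shuffle=True, seed=0):
    """Score a hypothetical perfect classifier against directory-order labels.

    `labels` plays the role of `generator.classes` (fixed directory order);
    `order` is the order in which the generator actually yields the samples.
    """
    labels = [0] * n_per_class + [1] * n_per_class
    order = list(range(len(labels)))
    if shuffle:
        random.Random(seed).shuffle(order)
    # A perfect model returns the true label of every sample it is shown,
    # but in the order in which the samples were yielded.
    y_pred = [labels[i] for i in order]
    return accuracy(labels, y_pred)


if __name__ == "__main__":
    print("shuffle=False:", simulate(shuffle=False))
    print("shuffle=True: ", simulate(shuffle=True))
```

Under these assumptions the unshuffled run scores 1.0, while the shuffled run lands near chance, which resembles the 0.49 in the question's report. If that is indeed the cause, creating the test generator with `shuffle=False` should let `tests.classes` line up with the predictions. The per-image loop above sidesteps the issue because each prediction is compared with the label of the very file it came from.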
