Keras:从图像数据生成器或预测生成器获取真标签(y_test)

ocebsuys  于 2023-01-13  发布在  其他
关注(0)|答案(4)|浏览(207)

我正在使用ImageDataGenerator().flow_from_directory(...)从目录生成批数据。
在模型成功构建之后,我希望得到一个包含True和Predicted类标签的两列数组。使用model.predict_generator(validation_generator, steps=NUM_STEPS),我可以得到一个包含预测类的numpy数组。是否可以让predict_generator输出相应的True类标签?
要添加的是:validation_generator.classes确实打印True标签,但是按照从目录中检索它们的顺序,它没有考虑批处理或通过增强进行的样本扩展。

du7egjpx

du7egjpx1#

您可以通过以下方式获取预测标签:

y_pred = numpy.rint(predictions)

你可以通过以下方式得到真正的标签:

y_true = validation_generator.classes

在此之前,您应该在验证生成器中设置shuffle=False
最后,您可以通过以下方式打印混淆矩阵
print confusion_matrix(y_true, y_pred)

jv4diomz

jv4diomz2#

还有另一种稍微“黑客”一点的方法来检索真正的标签。请注意,这种方法可以处理在生成器中设置shuffle=True的情况(一般来说,打乱数据是一个好主意-无论是在存储数据的地方手动进行,还是通过生成器进行,这可能更容易)。

# Create lists for storing the predictions and labels
predictions = []
labels = []

# Get the total number of labels in generator 
# (i.e. the length of the dataset where the generator generates batches from)
n = len(generator.labels)

# Loop over the generator
for data, label in generator:
    # Make predictions on data using the model. Store the results.
    predictions.extend(model.predict(data).flatten())

    # Store corresponding labels
    labels.extend(label)

    # We have to break out from the generator when we've processed 
    # the entire once (otherwise we would end up with duplicates). 
    if (len(label) < generator.batch_size) and (len(predictions) == n):
        break

您的预测和相应的标签现在应该分别存储在predictionslabels中。
最后,请记住,我们不应该在验证和测试集/生成器上添加数据扩充。

noj0wjuj

noj0wjuj3#

使用np.rint()方法会得到一个像[1.,0.,0.]这样的热编码结果,当我尝试用confusion_matrix(y_true, y_pred)创建一个混淆矩阵时,它导致了错误。因为validation_generator.classes返回的类标签是一个数字。
为了获取类编号,例如0,1,2作为指定的类标签,我发现本主题中的选择答案很有用。here

uyto3xhc

uyto3xhc4#

您应该尝试使用此方法来解析类概率,并根据得分将其转换为单个类。

if Y_preds.ndim !=1:
    Y_preds = np.argmax(Y_preds, axis=1)

相关问题