keras 如何绘制混淆矩阵

gv8xihay  于 2023-08-06  发布在  其他
关注(0)|答案(1)|浏览(177)

我试图用混淆矩阵来评估我的renet 50模型,但混淆矩阵看起来像这样:

matrix = confusion_matrix(y_test, y_pred, normalize="pred")
print(matrix)
    
# output
array([[1, 0],
      [1, 2]], dtype=int64)

字符串
我使用scikit-learn生成混淆矩阵,使用tf keras制作模型
但有没有办法绘制/可视化混淆矩阵?
我已经尝试使用sklearn.metrics.plot_confusion_matrix(matrix)
这个:How to plot Confusion Matrix但是我得到了这个:
x1c 0d1x的数据

gijlo24d

gijlo24d1#

包括以下导入:

from sklearn.metrics import ConfusionMatrixDisplay 
from matplotlib import pyplot as plt

字符串
现在,调用ConfusionMatrixDisplay函数并将matrix作为参数传递,如下所示:

disp = ConfusionMatrixDisplay(confusion_matrix=matrix) 
# Then just plot it: 
disp.plot() 
# And show it: 
plt.show()


此外,可以在ConfusionMatrixDisplay函数中将normalize参数设置为True,以在图中显示归一化计数。查看docs以获取更多参考和其他可接受的参数。

相关问题