keras 如何绘制混淆矩阵

gv8xihay  于 2023-08-06  发布在  其他
关注(0)|答案(1)|浏览(176)

我试图用混淆矩阵来评估我的renet 50模型,但混淆矩阵看起来像这样:

  1. matrix = confusion_matrix(y_test, y_pred, normalize="pred")
  2. print(matrix)
  3. # output
  4. array([[1, 0],
  5. [1, 2]], dtype=int64)

字符串
我使用scikit-learn生成混淆矩阵,使用tf keras制作模型
但有没有办法绘制/可视化混淆矩阵?
我已经尝试使用sklearn.metrics.plot_confusion_matrix(matrix)
这个:How to plot Confusion Matrix但是我得到了这个:
x1c 0d1x的数据

gijlo24d

gijlo24d1#

包括以下导入:

  1. from sklearn.metrics import ConfusionMatrixDisplay
  2. from matplotlib import pyplot as plt

字符串
现在,调用ConfusionMatrixDisplay函数并将matrix作为参数传递,如下所示:

  1. disp = ConfusionMatrixDisplay(confusion_matrix=matrix)
  2. # Then just plot it:
  3. disp.plot()
  4. # And show it:
  5. plt.show()


此外,可以在ConfusionMatrixDisplay函数中将normalize参数设置为True,以在图中显示归一化计数。查看docs以获取更多参考和其他可接受的参数。

相关问题