matplotlib 并排绘制多个混淆矩阵[重复]

rur96b6h  于 2023-03-09  发布在  其他
关注(0)|答案(1)|浏览(263)
    • 此问题在此处已有答案**:

(13个答案)
20小时前关门了。
我有一个3混淆矩阵,每一个相关的算法(SVM,LR,RF)
这是我的代码的混淆矩阵的SVM,LR和RF:

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
svm_A = [[11,1], 
        [2,2]]
df_cm = pd.DataFrame(svm_A, index = [i for i in "01"],
                  columns = [i for i in "01"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)

RF_A = [[12,0], 
        [3,1]]
df_cm = pd.DataFrame(RF_A, index = [i for i in "01"],
                  columns = [i for i in "01"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)

LR_A = [[11,1], 
        [2,2]]
df_cm = pd.DataFrame(RF_A, index = [i for i in "01"],
                  columns = [i for i in "01"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)

我想把三个情节并排放在一起。
我该怎么做呢?

w8biq8rn

w8biq8rn1#

您可以使用这种模块化的方法,灵活地添加更多的算法(在data中添加一个新的dict),并在一个地方修改绘制每个混淆矩阵的代码。

  • 密码 *
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

svm_A = [[11,1],
        [2,2]]
rf_A = [[12,0],
        [3,1]]
lr_A = [[11,1],
        [2,2]]

data = [
    {"algorithm_name": "SVM", "confusion_matrix": svm_A},
    {"algorithm_name": "RF", "confusion_matrix": rf_A},
    {"algorithm_name": "LR", "confusion_matrix": lr_A},
]

n_algorithms = len(data)
fig, axs = plt.subplots(1, n_algorithms, figsize=(5 * n_algorithms, 4.5))

for i, algo_data in enumerate(data):
    df_cm = pd.DataFrame(algo_data["confusion_matrix"], index=[j for j in "01"],
                         columns=[j for j in "01"])
    sn.heatmap(df_cm, annot=True, ax=axs[i])
    axs[i].set_title(algo_data["algorithm_name"])

plt.tight_layout()
plt.show()
  • 绘图 *

相关问题