matplotlib 如何绘制矩阵并为每列和每行提供说明

nszi6y05  于 2023-02-23  发布在  其他
关注(0)|答案(1)|浏览(123)

我有一个需要扩充的数据集,因此,我实现了一种称为幅度扭曲的扩充方法,该方法需要调整两个超参数,即sigmaknots。我有两个模型,我使用增强数据训练,并在部分真实的数据上测试,为了比较准确性,我也只在真实数据上训练模型。让我们假设下面的Python代码:

# test accuracy trained on real data only
ref_dt_accuracy = 0.86 
ref_lstm_accuracy = 0.85 

# test accuracy for each pair of hyperparameters
sigma = [0.2, 0.35, 0.5, 0.65]
knots = [4,5,6,7]

dt_accuracy_mw = [
[0.82, 0.85, 0.83, 0.84], 
[0.8, 0.79, 0.81, 0.79], 
[0.78,0.77, 0.74, 0.76], 
[0.74, 0.72, 0.78, 0.70]
]

lstm_accuracy_mw = [
[0.80, 0.83, 0.81, 0.82], 
[0.78, 0.77, 0.79, 0.77], 
[0.76,0.75, 0.72, 0.74], 
[0.72, 0.7, 0.76, 0.68]
]

现在,我想绘制两个(如果最后一个选项可行,则为三个)矩阵:
1.绘制dt_accuracy_mwlstm_accuracy_mw,使每个sigmaknots可视化:

sigma/knots 4  5  6  7
    0.2
    0.35    Actual matrix consisting of aforementioned accuracies
    0.5
    0.65

1.上述内容的组合版本,每个条目由dt_accuracy (ref_dt_accuracy - dt_accuracy)/lstm_accuracy (ref_lstm_accuracy - lstm_accuracy)组成,因此每个条目由dt_accuracy组成,dt_accuracy是与参考值的差值,lstm_accuracy是相同的。模型的每个准确度评分由/分隔
如何使用任何开源库(如matplotlib、seaborn等)来实现这一点?

7y4bm7vi

7y4bm7vi1#

您可以按如下方式创建海运热图:

from matplotlib import pyplot as plt
import seaborn as sns

sigma = [0.2, 0.35, 0.5, 0.65]
knots = [4, 5, 6, 7]

dt_accuracy_mw = [[0.82, 0.85, 0.83, 0.84],
                  [0.8, 0.79, 0.81, 0.79],
                  [0.78, 0.77, 0.74, 0.76],
                  [0.74, 0.72, 0.78, 0.70]]

ax = sns.heatmap(data=dt_accuracy_mw, xticklabels=knots, yticklabels=sigma,
                 linewidths=1, linecolor='blue', clip_on=False, annot=True, cbar=False,
                 cmap=sns.color_palette(['white'], as_cmap=True))
ax.set_xlabel('knots')
ax.set_ylabel('sigma')
plt.tight_layout()
plt.show()

如果我正确理解了第二个问题,那么注解矩阵就可以完成这项工作(data可以是具有正确宽度和高度的任何东西):

from matplotlib import pyplot as plt
import seaborn as sns

ref_dt_accuracy = 0.86
ref_lstm_accuracy = 0.85

sigma = [0.2, 0.35, 0.5, 0.65]
knots = [4, 5, 6, 7]

dt_accuracy_mw = [[0.82, 0.85, 0.83, 0.84],
                  [0.8, 0.79, 0.81, 0.79],
                  [0.78, 0.77, 0.74, 0.76],
                  [0.74, 0.72, 0.78, 0.70]]

lstm_accuracy_mw = [[0.80, 0.83, 0.81, 0.82],
                    [0.78, 0.77, 0.79, 0.77],
                    [0.76, 0.75, 0.72, 0.74],
                    [0.72, 0.7, 0.76, 0.68]]
annot_matrix = [[f'{ref_dt_accuracy - dt:.2f} / {ref_lstm_accuracy - lstm:.2f}'
                 for dt, lstm in zip(dt_row, lstm_row)]
                for dt_row, lstm_row in zip(dt_accuracy_mw, lstm_accuracy_mw)]

ax = sns.heatmap(data=dt_accuracy_mw, xticklabels=knots, yticklabels=sigma,
                 annot=annot_matrix, fmt='',
                 linewidths=2, linecolor='crimson', clip_on=False, cbar=False,
                 cmap=sns.color_palette(['aliceblue'], as_cmap=True))
ax.set_xlabel('knots')
ax.set_ylabel('sigma')
plt.tight_layout()
plt.show()

相关问题