matplotlib 如何自定义海上热图的颜色条?

niknxzdl  于 2023-05-18  发布在  其他
关注(0)|答案(1)|浏览(184)

背景:我比较了13个模型的性能,通过使用它们中的每一个模型对四个数据集进行预测。现在我有4 * 13个R平方值,表示拟合优度。问题是存在一些较大的负R平方值,使得可视化不那么好。

正的R平方值很难区分,因为这些负值如-11或-9.7。如何通过自定义色带来扩展正范围并压缩负范围?代码和数据如下。

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
data = np.array([[  0.9848,   0.    ,   0.9504,  -0.8198,   0.9501,   0.9071,
          0.8598,   0.9348,   0.    ,   0.713 ,   0.    ,   0.669 ,
          0.6184,   0.    ],
       [  0.9733,   0.    ,   0.0566,  -9.654 ,   0.1291,  -0.0926,
         -0.0661,  -2.3085,   0.    , -10.63  ,   0.    ,  -3.797 ,
         -7.592 ,   0.    ],
       [  0.9676,   0.    ,   0.9331,   0.9177,   0.9401,   0.9352,
          0.9251,   0.7987,   0.    ,   0.5635,   0.    ,   0.5924,
          0.2456,   0.    ],
       [  0.9759,   0.    ,  -0.114 ,   0.1566,   0.0412,   0.3588,
          0.2605,  -0.5471,   0.    ,   0.2534,   0.    ,   0.5216,
          0.3784,   0.    ]])
def comp_heatmap(ax):
    with sns.axes_style('white'):
        ax = sns.heatmap(
            data, ax=ax, vmax=.3,
            annot=True,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
        )
    ax.set_xlabel('Model', fontdict=font_text)
    ax.set_ylabel(r'$R^2$', fontproperties=font_formula, labelpad=5)
    ax.figure.colorbar(ax.collections[0])
    # set tick labels
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int))
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(['lnr, fit', 'lg, fit', 'lnr, test', 'lg, test'])

comp_heatmap(ax)
8yparm6h

8yparm6h1#

我使用了FuncNorm方法来解决它。

from matplotlib import pyplot as plt, font_manager as fm, colors

def forward(x):
    x = base ** x - 1
    return x

def inverse(x):
    x = np.log(x + 1) / np.log(base)
    return x

def comp_heatmap(ax):
    plt.rc('font', family='Times New Roman', size=15)
    plt.subplots_adjust(left=0.05, right=1)
    norm = colors.FuncNorm((forward, inverse), vmin=-11, vmax=1)
    mask = np.zeros_like(data)
    mask[:, [1, 8, 10, 13]] = 1
    mask = mask.astype(np.bool)
    with sns.axes_style('white'):
        ax = sns.heatmap(
            data, ax=ax, vmax=.3,
            mask=mask,
            annot=True, fmt='.4',
            annot_kws=font_annot,
            norm=norm,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
            cbar=False,
            cmap='rainbow'
        )    
    cbar = ax.figure.colorbar(ax.collections[0])
    cbar.set_ticks([-11, -0.5, 0, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
    # set tick labels
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int), **font_tick)
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(['', '', '', ''])
    return ax

font_formula = fm.FontProperties(
    math_fontfamily='cm', size=22
)
font_text = {'size': 22, 'fontfamily': 'Times New Roman'}
font_annot = {'size': 17, 'fontfamily': 'Times New Roman'}
font_tick = {'size': 18, 'fontfamily': 'Times New Roman'}
fig, axes = plt.subplots()
base = 5
ax = comp_heatmap(axes)

相关问题