matplotlib 为多个axvspan显示不同颜色的图例

iq3niunx  于 2023-03-19  发布在  其他
关注(0)|答案(1)|浏览(222)

如何在图例中为不同的axvspan绘制不同的颜色?
我不理解matplotlib的行为:

def plot_test_results(df, c, t_start, t_end):
    t_start = [datetime.strptime(t, '%Y-%m-%d') for t in t_start]
    t_end = [datetime.strptime(t, '%Y-%m-%d') for t in t_end]
    for t1, t2 in zip(t_start, t_end):
        gs = gridspec.GridSpec(2, 1, height_ratios=[2.5,1])
        ax=plt.subplot(gs[0])
        y_start = list(df[t1:t2][df.loc[t1:t2, 'y_pred'].diff(-1) < 0].index)
        y_end = list(df[t1:t2][df.loc[t1:t2, 'y_pred'].diff(-1) > 0].index)
        crash_st = list(filter(lambda x: x > t1 and x < t2, c['crash_st']))
        crash_end = list(filter(lambda x: x > t1 and x < t2, c['crash_end']))
        plt.plot(df['price'][t1:t2], color='blue') 
        [plt.axvspan(x1, x2, alpha=0.4, color='orange', label='prediction', zorder=2) for x1, x2 in zip(y_start, y_end)]
        [plt.axvspan(c1, c2, alpha=0.8, color='red', label='crashes') for c1, c2 in zip(crash_st, crash_end)]
        plt.legend(['Price', 'Crash', 'Crash Prediction'])
        plt.title(test_data + ' ' + model_name +  ', Time period: ' + str(calendar.month_name[t1.month]) + ' ' + str(t1.year) + ' - ' +\
                 str(calendar.month_name[t2.month]) + ' ' + str(t2.year))
        plt.show()

我得到的结果是:

谁能给我解释一下,我怎么能在图例中用红色表示碰撞,用橙子表示碰撞预测?
先谢谢你了,我已经纠结了一段时间了。

vngu2lb8

vngu2lb81#

要在图例中为“崩溃”和“崩溃预测”标签显示不同的颜色,可以创建两个具有所需颜色和标签的单独Line2D对象,然后将这些对象传递给图例函数。

from matplotlib.lines import Line2D

def plot_test_results(df, c, t_start, t_end):
    t_start = [datetime.strptime(t, '%Y-%m-%d') for t in t_start]
    t_end = [datetime.strptime(t, '%Y-%m-%d') for t in t_end]
    for t1, t2 in zip(t_start, t_end):
        gs = gridspec.GridSpec(2, 1, height_ratios=[2.5,1])
        ax=plt.subplot(gs[0])
        y_start = list(df[t1:t2][df.loc[t1:t2, 'y_pred'].diff(-1) < 0].index)
        y_end = list(df[t1:t2][df.loc[t1:t2, 'y_pred'].diff(-1) > 0].index)
        crash_st = list(filter(lambda x: x > t1 and x < t2, c['crash_st']))
        crash_end = list(filter(lambda x: x > t1 and x < t2, c['crash_end']))
        plt.plot(df['price'][t1:t2], color='blue')
        prediction_patches = [plt.axvspan(x1, x2, alpha=0.4, color='orange', zorder=2) for x1, x2 in zip(y_start, y_end)]
        crash_patches = [plt.axvspan(c1, c2, alpha=0.8, color='red') for c1, c2 in zip(crash_st, crash_end)]
        prediction_patch = Line2D([0], [0], color='orange', alpha=0.4, lw=4, label='Crash Prediction')
        crash_patch = Line2D([0], [0], color='red', alpha=0.8, lw=4, label='Crash')
        plt.legend(handles=[prediction_patch, crash_patch])
        plt.title(test_data + ' ' + model_name +  ', Time period: ' + str(calendar.month_name[t1.month]) + ' ' + str(t1.year) + ' - ' +\
                 str(calendar.month_name[t2.month]) + ' ' + str(t2.year))
        plt.show()

相关问题