Modify the legend colors of a matplotlib plot created by gluonts

dtcbnfnu  posted on 2023-10-24  in  Other

I am using gluonts and plotting a forecast (the code comes from the DeepVaR notebook). The code is:

def plot_prob_forecasts(ts_entry, forecast_entry, asset_name, plot_length=20):
    prediction_intervals = (0.95, 0.99)
    legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ts_entry[-plot_length:].plot(ax=ax)  # plot the time series
    forecast_entry.plot( intervals=prediction_intervals, color='g')    
    plt.grid(which="both")
    plt.legend(legend, loc="upper left")
    plt.title(f'Forecast of {asset_name} series Returns')
    plt.show()

which produces the following plot:

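The colors shown in the legend for the confidence intervals are wrong, but I can't work out how to fix them. Calling plt.gca().get_legend_handles_labels() returns only the first line (the observations), and the output is the same whether I call it before or after legend(). The relevant code from gluonts is: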

def plot(
        self,
        *,
        intervals=(0.5, 0.9),
        ax=None,
        color=None,
        name=None,
        show_label=False,
    ):
        """
        Plot median forecast and prediction intervals using ``matplotlib``.

        By default the `0.5` and `0.9` prediction intervals are plotted. Other
        intervals can be choosen by setting `intervals`.

        This plots to the current axes object (via ``plt.gca()``), or to ``ax``
        if provided. Similarly, the color is using matplotlibs internal color
        cycle, if no explicit ``color`` is set.

        One can set ``name`` to use it as the ``label`` for the median
        forecast. Intervals are not labeled, unless ``show_label`` is set to
        ``True``.
        """
        import matplotlib.pyplot as plt

        # Get current axes (gca), if not provided explicitly.
        ax = maybe.unwrap_or_else(ax, plt.gca)

        # If no color is provided, we use matplotlib's internal color cycle.
        # Note: This is an internal API and might change in the future.
        color = maybe.unwrap_or_else(
            color, lambda: ax._get_lines.get_next_color()
        )

        # Plot median forecast
        ax.plot(
            self.index.to_timestamp(),
            self.quantile(0.5),
            color=color,
            label=name,
        )

        # Plot prediction intervals
        for interval in intervals:
            if show_label:
                if name is not None:
                    label = f"{name}: {interval}"
                else:
                    label = interval
            else:
                label = None

            # Translate interval to low and high values. E.g for `0.9` we get
            # `low = 0.05` and `high = 0.95`. (`interval + low + high == 1.0`)
            # Also, higher interval values mean lower confidence, and thus we
            # we use lower alpha values for them.
            low = (1 - interval) / 2
            ax.fill_between(
                # TODO: `index` currently uses `pandas.Period`, but we need
                # to pass a timestamp value to matplotlib. In the future this
                # will use ``zebras.Periods`` and thus needs to be adapted.
                self.index.to_timestamp(),
                self.quantile(low),
                self.quantile(1 - low),
                # Clamp alpha betwen ~16% and 50%.
                alpha=0.5 - interval / 3,
                facecolor=color,
                label=label,
            )

If I set color=None, I get an error from matplotlib. Setting show_label=True and passing a name doesn't work either. Any ideas on how to fix this?
python=3.9.18
matplotlib=3.8.0
gluonts=0.13.2

biswetbf1#

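plt.legend normally uses the "labeled" matplotlib elements it finds in the plot. In this case, the dark green area is made up of two superimposed semi-transparent layers. The default behavior would show only the individual semi-transparent layers. You can pass a tuple of handles to draw one on top of the other.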
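Both points can be checked in isolation with plain matplotlib. The following is a minimal sketch that does not use gluonts; the data and the names `inner`, `outer` and `auto_labels` are made up for illustration. It shows that artists drawn without a label are skipped by get_legend_handles_labels(), and that a tuple of handles is drawn as stacked layers in a single legend entry.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)
fig, ax = plt.subplots()

ax.plot(x, x, label="observations")   # labeled: found automatically
ax.plot(x, x + 1, color="g")          # no label: skipped by the legend
# Two semi-transparent bands without labels, like the forecast intervals
inner = ax.fill_between(x, x, x + 2, alpha=0.2, facecolor="g")
outer = ax.fill_between(x, x - 1, x + 3, alpha=0.2, facecolor="g")

# Only artists with a label that does not start with "_" are returned
auto_labels = ax.get_legend_handles_labels()[1]
print(auto_labels)  # ['observations']

# Explicit handles: a tuple is rendered as its members drawn on top of
# each other, which reproduces the darker color of the overlap region
leg = ax.legend(
    handles=[outer, (inner, outer)],
    labels=["outer band only", "overlap of both bands"],
    loc="upper left",
)
```

Call plt.show() afterwards to compare the two legend swatches with the shaded regions in the plot.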
Here is some simplified standalone code that simulates your situation.

import matplotlib.pyplot as plt
import numpy as np

# create some dummy test data
x = np.arange(30)
y = np.random.normal(0.1, 1, size=x.size).cumsum()

fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# simulate plotting the observations
ax.plot(x, y)

prediction_intervals = (0.95, 0.99)
# format 0.95 as "95%" (the original f"{k}%" would print "0.95%")
legend = ["observations", "median prediction"] + [f"{100 * k:g}% prediction interval" for k in prediction_intervals][::-1]
color = 'g'

xp = np.arange(10, 30)
yp = y[-xp.size:] + np.random.normal(0.03, 0.2, size=xp.size).cumsum()
# simulate plotting median forecast
ax.plot(xp, yp, color=color, label=None)

# simulate plotting the prediction intervals
for interval in prediction_intervals:
    ax.fill_between(xp, yp - interval ** 2 * 5, yp + interval ** 2 * 5,
                    alpha=0.5 - interval / 3,
                    facecolor=color,
                    label=None)

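# the matplotlib elements for the two curves (observations, median forecast)
handle_l1, handle_l2 = ax.lines[:2]
# the matplotlib elements for the two filled areas, in plotting order:
# handle_c1 is the 0.95 band, handle_c2 is the wider 0.99 band
handle_c1, handle_c2 = ax.collections[:2]

# the 99% region is covered only by the 0.99 band; the 95% region is covered
# by both bands, so its legend entry stacks the two handles in a tuple
ax.legend(handles=[handle_l1, handle_l2, handle_c2, (handle_c1, handle_c2)],
          labels=legend, loc="upper left")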
ax.grid(which="both")

plt.tight_layout()
plt.show()
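The same idea can be wrapped in a small helper and applied to axes that were already drawn on. This is only a sketch under assumptions: `layered_legend` is a hypothetical name, not part of gluonts or matplotlib. It assumes that ax.lines holds the observations followed by the median forecast, and that the last entries of ax.collections are the interval bands in the same order as `intervals`, which matches the drawing order in the gluonts source quoted in the question. The data below is simulated so that the block runs without gluonts.

```python
import matplotlib.pyplot as plt
import numpy as np


def layered_legend(ax, intervals, **legend_kwargs):
    """Hypothetical helper: legend whose interval swatches stack like the plot.

    Assumes ax.lines starts with [observations, median] and that the last
    len(intervals) collections on ax are the bands, ordered like `intervals`.
    """
    observed, median = list(ax.lines)[:2]
    bands = list(ax.collections)[-len(intervals):]
    # Widest interval first: its outer rim is covered by a single layer
    order = sorted(range(len(intervals)), key=lambda i: intervals[i], reverse=True)

    handles = [observed, median]
    labels = ["observations", "median prediction"]
    for rank, i in enumerate(order):
        # The region of interval i lies under every band at least as wide
        handles.append(tuple(bands[j] for j in order[: rank + 1]))
        labels.append(f"{100 * intervals[i]:g}% prediction interval")
    return ax.legend(handles=handles, labels=labels, **legend_kwargs)


# Simulated stand-in for ts_entry.plot(ax=ax) and forecast_entry.plot(...)
rng = np.random.default_rng(0)
x = np.arange(30)
y = rng.normal(0.1, 1, size=x.size).cumsum()
intervals = (0.95, 0.99)

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(x, y)                        # observations
ax.plot(x[10:], y[10:], color="g")   # median forecast
for k in intervals:
    ax.fill_between(x[10:], y[10:] - 5 * k**2, y[10:] + 5 * k**2,
                    alpha=0.5 - k / 3, facecolor="g")

legend = layered_legend(ax, intervals, loc="upper left")
```

In plot_prob_forecasts from the question, the call plt.legend(legend, loc="upper left") could probably be replaced by layered_legend(ax, prediction_intervals, loc="upper left"), provided ts_entry contains a single series so that the assumed artist order holds.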
