matplotlib 逐步向图例添加自定义条目

tpxzln5u  于 2023-05-01  发布在  其他
关注(0)|答案(1)|浏览(127)

我想在这样的多功能中逐步建立一个情节

示例

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

def plot_on_axis(ax: plt.Axes, x: np.ndarray, y: np.ndarray, color, name) -> plt.Axes:
    ax.plot(x, y, color=color, label="orig")
    ax.plot(x, y + 0.2, "--", color=color, label="shifted")
    patch = mpatches.Patch(color=color, label=name)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles + [patch], labels + [name])
    return ax

def get_fig() -> plt.Figure:
    x1 = np.linspace(0, 3)
    y1 = np.sin(x1)

    x2 = np.linspace(0, 3)
    y2 = np.cos(x2)

    fig = plt.figure()
    ax = fig.subplots()
    plot_on_axis(ax, x1, y1, "tab:blue", "sin")
    plot_on_axis(ax, x2, y2, "tab:orange", "cos")

    return fig

get_fig().show()

问题

但是,这会覆盖图例中的sin条目,因此仅显示cos

因为对get_legend_handles_labels的第二次调用仅返回四个元素,而不是相加的一个(如果它将返回全部,则将存在sin的重复条目)。
有没有办法在plot_on_axis中构建图例,或者应该在get_fig中处理图例?在plot_on_axis中处理它对我来说似乎要优雅得多,除了这个问题。
或者,是否有更好的方式将条目的分组传达给图的查看者?

eoigrqb6

eoigrqb61#

您可以返回一个Artist的列表,而不是返回您根本不使用的Axes示例,然后使用这些美工人员在get_fig中创建自定义图例。

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from typing import List

def plot_on_axis(ax: plt.Axes, x: np.ndarray, y: np.ndarray, color, name) -> List[plt.Artist]:
    line1, = ax.plot(x, y, color=color, label="orig")
    line2, = ax.plot(x, y + 0.2, "--", color=color, label="shifted")
    patch = mpatches.Patch(color=color, label=name)
    return [line1, line2, patch]

def get_fig() -> plt.Figure:
    x1 = np.linspace(0, 3)
    y1 = np.sin(x1)

    x2 = np.linspace(0, 3)
    y2 = np.cos(x2)

    fig = plt.figure()
    ax = fig.subplots()
    handles1 = plot_on_axis(ax, x1, y1, "tab:blue", "sin")
    handles2 = plot_on_axis(ax, x2, y2, "tab:orange", "cos")

    # Create the full legend.
    handles = handles1 + handles2
    labels = [handle.get_label() for handle in handles]
    ax.legend(handles, labels)

    return fig

get_fig().show()

使用ax.legend(handles, labels, ncol=2)可能是进一步分离两组数据的好方法:

相关问题