matplotlib 如何在线图的线上绘制标签?

ig9co6j1  于 2023-08-06  发布在  其他
关注(0)|答案(2)|浏览(122)

我想在matplotlib中将标签绘制在线图的一条线上。

最小示例

#!/usr/bin/env python
# Minimal example: plot eight smoothed random walks and identify them with a
# regular legend box (the behaviour the question wants to replace).
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# `scipy.ndimage.filters` is a deprecated namespace; import from
# `scipy.ndimage` directly.
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data: a random walk, circularly shifted by a random offset and
    # then smoothed with a Gaussian kernel (sigma = 10 samples)
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # seaborn no longer exposes pyplot as `sns.plt`; use pyplot directly
    plt.plot(y, label=str(i))
plt.legend()
plt.show()

上述代码生成如下图像:

(图片)

相反,我更希望得到类似下图的效果:


rwqw0loc

rwqw0loc1#

也许有点古怪,但这能解决你的问题吗?

#!/usr/bin/env python
# Plot eight smoothed random walks and write each line's label onto the line
# itself at three fixed x-positions.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# `scipy.ndimage.filters` is a deprecated namespace; import from
# `scipy.ndimage` directly.
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # seaborn no longer exposes pyplot as `sns.plt`; use pyplot directly
    p = plt.plot(y, label=str(i))
    # Reuse the colour the palette assigned to this line for its labels
    color = p[0].get_color()
    for x in [250, 500, 750]:
        y2 = y[x]
        # A white disc first, so the line does not show through the label
        plt.plot(x, y2, 'o', color='white', markersize=9)
        # The label itself, rendered as a mathtext marker. The former 'k'
        # format string was overridden by `color=` anyway and only caused a
        # "redundantly defined" warning, so it is dropped.
        plt.plot(x, y2, marker="$%s$" % str(i), color=color,
                 markersize=7)
plt.legend()
plt.show()

下面是我得到的结果:

(图片)

**编辑:**我又想了想,想出了一个解决方案:它会自动尝试为标签找到最佳位置,避免把标签放在两条线彼此非常接近的x值处(否则可能导致标签之间相互重叠):

#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# `scipy.ndimage.filters` is a deprecated namespace; import from
# `scipy.ndimage` directly.
from scipy.ndimage import gaussian_filter1d

# Global plot styling: white grid background and an 8-step green palette
sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

# -----------------------------------------------------------------------------

def inline_legend(lines, n_markers=1):
    """
    Label the lines of a plot directly on the lines themselves.

    Take a sequence containing the lines of a plot (typically the result of
    calling plt.gca().get_lines()) and draw the label of every line
    n_markers times onto that line. Each label is placed at the data point
    of its line where the vertical distance to the closest other line is
    largest, so labels of neighbouring lines are unlikely to overlap.
    [Source of problem: https://stackoverflow.com/q/43573623/4100721]

    Parameters
    ----------
    lines : sequence of line objects
        Every element must provide get_xdata(), get_ydata(), get_label()
        and get_color() (e.g. matplotlib Line2D instances). If an element
        has a non-None ``axes`` attribute the labels are drawn on that axes,
        otherwise on the current pyplot axes.
    n_markers : int, optional
        Number of labels per line (default 1). Values below 1 draw nothing.
        If a line has fewer data points than 3 * n_markers, fewer labels
        are drawn for it.

    Returns
    -------
    None
        The labels are added to the axes as additional plot artists.
    """

    import numpy as np
    from scipy.interpolate import interp1d

    def chunkify(a, n):
        """
        Split sequence a into n approximately equally sized chunks and
        return the (start, end) index pairs of those chunks.
        [Idea: Props to http://stackoverflow.com/a/2135920/4100721 :)]
        """
        k, m = divmod(len(a), n)
        return [(i * k + min(i, m), (i + 1) * k + min(i + 1, m))
                for i in range(n)]

    # Work on a snapshot: the labels drawn below are lines as well and must
    # not show up in the sequence being iterated (relevant if a live view
    # such as ax.lines is passed in).
    lines = list(lines)
    if n_markers < 1:
        return

    # Calculate linear interpolations of every line. This is necessary to
    # compare the values of the lines if they use different x-values.
    # With bounds_error=False, x-values outside a line's own range evaluate
    # to NaN instead of raising.
    interpolations = [interp1d(_.get_xdata(), _.get_ydata(),
                               bounds_error=False)
                      for _ in lines]

    # Loop over all lines
    for idx, line in enumerate(lines):

        # Get basic properties of the current line
        label = line.get_label()
        color = line.get_color()
        x_values = np.asarray(line.get_xdata())
        y_values = np.asarray(line.get_ydata())

        # Draw on the axes that owns the line; fall back to the current
        # pyplot axes for lines that are not attached to any axes.
        ax = getattr(line, "axes", None)
        if ax is None:
            import matplotlib.pyplot as plt
            ax = plt.gca()

        # Linear interpolations of all lines except the current one
        other_functions = interpolations[0:idx] + interpolations[idx+1:]

        # For every data point of the current line, compute the vertical
        # distance to the closest other line. Index the y-values by
        # position (not by x-value) so arbitrary x-grids work.
        if other_functions:
            gaps = np.abs(np.array([f(x_values) for f in other_functions])
                          - y_values)
            # Where another line has no data it cannot collide with a label
            gaps = np.where(np.isnan(gaps), np.inf, gaps)
            clearance = gaps.min(axis=0)
            # Never put a label on a missing data point of the line itself
            clearance[np.isnan(y_values)] = -np.inf
        else:
            # A single line has no neighbours to avoid
            clearance = None

        # Split the x-values in chunks to get regions in which to put
        # labels. Creating 3 times as many chunks as requested and using only
        # every third ensures that no two labels for the same line are too
        # close to each other.
        chunks = chunkify(x_values, 3*n_markers)[::3]

        # For each chunk, find the optimal position of the label
        for chunk_start, chunk_end in chunks:

            # Fewer data points than chunks: nothing to label here
            if chunk_start == chunk_end:
                continue

            if clearance is None:
                # No other lines: simply use the middle of the chunk
                best_pos = (chunk_start + chunk_end - 1) // 2
            else:
                # Index (relative to the whole line) where the distance to
                # the closest other line is maximal within this chunk
                best_pos = chunk_start + int(
                    np.argmax(clearance[chunk_start:chunk_end]))

            # Data coordinates of the label
            x = x_values[best_pos]
            y = y_values[best_pos]

            # Actually plot the label onto the line at the calculated
            # position: a white disc as background, then the label text
            # rendered as a mathtext marker in the line's colour.
            ax.plot(x, y, 'o', color='white', markersize=9)
            ax.plot(x, y, marker="$%s$" % label, color=color,
                    markersize=7)

# -----------------------------------------------------------------------------

# Demo: plot eight smoothed random walks and label each of them three times
# directly on the line.
for i in range(8):
    # Create data
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # seaborn no longer exposes pyplot as `sns.plt`; use pyplot directly
    plt.plot(y, label=str(i))

inline_legend(plt.gca().get_lines(), n_markers=3)
plt.show()


此解决方案的示例输出(注意标签的x位置如何不再完全相同):

如果想要避免使用scipy.interpolate.interp1d,可以考虑这样的方案:对于线A的给定x值,找到线B中与之最接近的x值。不过我认为,如果各条线使用差异很大和/或很稀疏的网格,这种做法可能会有问题。

tktrz96b

tktrz96b2#

这个问题已经很老了,但你可以用 matplotlib-label-lines 包来实现:

# Label the lines with the matplotlib-label-lines package
# (pip install matplotlib-label-lines).
import matplotlib.pyplot as plt  # was used below without being imported
import numpy as np
import seaborn as sns
from labellines import labelLines  # was used below without being imported
# `scipy.ndimage.filters` is a deprecated namespace; import from
# `scipy.ndimage` directly.
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # legend=False: no legend box, the labels go onto the lines instead
    sns.lineplot(y, label=str(i), legend=False)

# Write the label of every line of the current axes onto the line itself
labelLines(plt.gca().get_lines())

字符串

相关问题