matplotlib pyplot在图例中合并多个线标签

k5hmc34c  于 2023-08-06  发布在  其他
关注(0)|答案(9)|浏览(140)

我的数据会导致绘制多条线,我想在图例中为这些线提供一个标签。我想用下面的例子可以更好地说明这一点,

a = np.array([[ 3.57,  1.76,  7.42,  6.52],
              [ 1.57,  1.2 ,  3.02,  6.88],
              [ 2.23,  4.86,  5.12,  2.81],
              [ 4.48,  1.38,  2.14,  0.86],
              [ 6.68,  1.72,  8.56,  3.23]])

plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')

plt.legend(loc='best')

字符串
正如你在Out[23]中看到的,该图产生了5条不同的线。生成的图如下所示:

有没有什么方法可以让plot方法避免多个标签?我不想尽可能多地使用自定义图例(在那里你一次指定标签和线条形状)。

h43kikqp

h43kikqp1#

如果我打算经常做的话,我会自己做一个小的辅助函数;

from matplotlib import pyplot
import numpy

a = numpy.array([[ 3.57,  1.76,  7.42,  6.52],
                 [ 1.57,  1.2 ,  3.02,  6.88],
                 [ 2.23,  4.86,  5.12,  2.81],
                 [ 4.48,  1.38,  2.14,  0.86],
                 [ 6.68,  1.72,  8.56,  3.23]])

def plotCollection(ax, xs, ys, *args, **kwargs):

  ax.plot(xs,ys, *args, **kwargs)

  if "label" in kwargs.keys():

    #remove duplicates
    handles, labels = pyplot.gca().get_legend_handles_labels()
    newLabels, newHandles = [], []
    for handle, label in zip(handles, labels):
      if label not in newLabels:
        newLabels.append(label)
        newHandles.append(handle)

    pyplot.legend(newHandles, newLabels)

ax = pyplot.subplot(1,1,1)  
plotCollection(ax, a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
plotCollection(ax, a[:,1::2].T, a[:, ::2].T, 'b', label='data_b')
pyplot.show()

字符串
handleslabels的图例中删除重复项(比您现有的)的一种更简单(也更清晰)的方法是:

handles, labels = pyplot.gca().get_legend_handles_labels()
newLabels, newHandles = [], []
for handle, label in zip(handles, labels):
  if label not in newLabels:
    newLabels.append(label)
    newHandles.append(handle)
pyplot.legend(newHandles, newLabels)

rekjcdws

rekjcdws2#

Numpy的解决方案基于上面Will的回答。

import numpy as np
import matplotlib.pylab as plt
a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
handles, labels = plt.gca().get_legend_handles_labels()

字符串
假设相等的标签有相等的句柄,得到唯一的标签和它们各自的索引,它们对应于句柄索引。

labels, ids = np.unique(labels, return_index=True)
handles = [handles[i] for i in ids]
plt.legend(handles, labels, loc='best')
plt.show()

new9mtju

new9mtju3#

Matplotlib为您提供了一个很好的行集合接口LineCollection。代码很简单

import numpy
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

a = numpy.array([[ 3.57,  1.76,  7.42,  6.52],
                 [ 1.57,  1.2 ,  3.02,  6.88],
                 [ 2.23,  4.86,  5.12,  2.81],
                 [ 4.48,  1.38,  2.14,  0.86],
                 [ 6.68,  1.72,  8.56,  3.23]])

xs = a[:,::2]
ys = a[:, 1::2]
lines = LineCollection([list(zip(x,y)) for x,y in zip(xs, ys)], label='data_a')
f, ax = plt.subplots(1, 1)
ax.add_collection(lines)
ax.legend()
ax.set_xlim([xs.min(), xs.max()]) # have to set manually
ax.set_ylim([ys.min(), ys.max()])
plt.show()

字符串
这将产生以下输出:

的数据

tzdcorbm

tzdcorbm4#

一个低技术的解决方案是做两个情节调用。一个绘制您的数据,另一个只绘制句柄:

a = np.array([[ 3.57,  1.76,  7.42,  6.52],
              [ 1.57,  1.2 ,  3.02,  6.88],
              [ 2.23,  4.86,  5.12,  2.81],
              [ 4.48,  1.38,  2.14,  0.86],
              [ 6.68,  1.72,  8.56,  3.23]])

plt.plot(a[:,::2].T, a[:, 1::2].T, 'r')
plt.plot([],[], 'r', label='data_a')

plt.legend(loc='best')

字符串
结果如下:


的数据

hfsqlsce

hfsqlsce5#

因此,使用will的建议和另一个问题here,我在这里留下我的补救措施

handles, labels = plt.gca().get_legend_handles_labels()
i =1
while i<len(labels):
    if labels[i] in labels[:i]:
        del(labels[i])
        del(handles[i])
    else:
        i +=1

plt.legend(handles, labels)

字符串
新的图看起来像,

5f0d552i

5f0d552i6#

最简单和最pythonic的方法来删除重复是使用一个字典的键,这是保证是唯一的。这也确保了我们只对每个(handle,label)对迭代一次。

handles, labels = plt.gca().get_legend_handles_labels()

# labels will be the keys of the dict, handles will be values
temp = {k:v for k,v in zip(labels, handles)}

plt.legend(temp.values(), temp.keys(), loc='best')

字符串

jmo0nnb3

jmo0nnb37#

我会做这个把戏:

for i in range(len(a)):
  plt.plot(a[i,::2].T, a[i, 1::2].T, 'r', label='data_a' if i==0 else None)

字符串

ig9co6j1

ig9co6j18#

我找到了一个简单的方法来解决这个问题:

a = np.array([[ 3.57,  1.76,  7.42,  6.52],
              [ 1.57,  1.2 ,  3.02,  6.88],
              [ 2.23,  4.86,  5.12,  2.81],
              [ 4.48,  1.38,  2.14,  0.86],
              [ 6.68,  1.72,  8.56,  3.23]])

p1=plt.plot(a[:,::2].T, a[:, 1::2].T, color='r')
plt.legend([p1[0]],['data_a'],loc='best')

字符串

iyr7buue

iyr7buue9#

如果提供“无”作为标签,则不打印标签。你就这么做吧

a = np.array([[ 3.57,  1.76,  7.42,  6.52],
              [ 1.57,  1.2 ,  3.02,  6.88],
              [ 2.23,  4.86,  5.12,  2.81],
              [ 4.48,  1.38,  2.14,  0.86],
              [ 6.68,  1.72,  8.56,  3.23]])

plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label=['data_a', None, None, None, None])

plt.legend(loc='best')

字符串
Plot From The Code Above
一个更懒惰的选择:

a = np.array([[ 3.57,  1.76,  7.42,  6.52],
              [ 1.57,  1.2 ,  3.02,  6.88],
              [ 2.23,  4.86,  5.12,  2.81],
              [ 4.48,  1.38,  2.14,  0.86],
              [ 6.68,  1.72,  8.56,  3.23]])

n_labels = 5
labels = [None for _ in range(n_labels)]
labels[0] = 'data_a'
plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label=labels)

plt.legend(loc='best')

相关问题