matplotlib 我可以创建AxesSubplot对象,然后将它们添加到Figure示例吗?

jexiocij  于 2023-04-07  发布在  其他
关注(0)|答案(5)|浏览(154)

查看matplotlib文档,将AxesSubplot添加到Figure的标准方法似乎是使用Figure.add_subplot

from matplotlib import pyplot

fig = pyplot.figure()
ax = fig.add_subplot(1,1,1)
ax.hist( some params .... )

我希望能够独立于图形创建类似AxesSubPlot的对象,这样我就可以在不同的图形中使用它们。

fig = pyplot.figure()
histoA = some_axes_subplot_maker.hist( some params ..... )
histoA = some_axes_subplot_maker.hist( some other params ..... )
# make one figure with both plots
fig.add_subaxes(histo1, 211)
fig.add_subaxes(histo1, 212)
fig2 = pyplot.figure()
# make a figure with the first plot only
fig2.add_subaxes(histo1, 111)

这在matplotlib中可能吗?如果可能,我该怎么做?

**更新:**我还没有成功地将轴和图的创建分离,但下面的答案中的示例可以轻松地在新的或旧的图示例中重用以前创建的轴。这可以用一个简单的函数来说明:

def plot_axes(ax, fig=None, geometry=(1,1,1)):
    if fig is None:
        fig = plt.figure()
    if ax.get_geometry() != geometry :
        ax.change_geometry(*geometry)
    ax = fig.axes.append(ax)
    return fig
ibps3vxo

ibps3vxo1#

通常,您只需将axes示例传递给函数。
例如:

import matplotlib.pyplot as plt
import numpy as np

def main():
    x = np.linspace(0, 6 * np.pi, 100)

    fig1, (ax1, ax2) = plt.subplots(nrows=2)
    plot(x, np.sin(x), ax1)
    plot(x, np.random.random(100), ax2)

    fig2 = plt.figure()
    plot(x, np.cos(x))

    plt.show()

def plot(x, y, ax=None):
    if ax is None:
        ax = plt.gca()
    line, = ax.plot(x, y, 'go')
    ax.set_ylabel('Yabba dabba do!')
    return line

if __name__ == '__main__':
    main()

为了回答你的问题,你总是可以这样做:

def subplot(data, fig=None, index=111):
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(index)
    ax.plot(data)

此外,您可以简单地将轴示例添加到另一个图形:

import matplotlib.pyplot as plt

fig1, ax = plt.subplots()
ax.plot(range(10))

fig2 = plt.figure()
fig2.axes.append(ax)

plt.show()

调整它的大小以匹配其他子图“形状”也是可能的,但它很快就会变得比它的价值更麻烦。根据我的经验,对于复杂的情况,只是传递一个图或轴示例(或示例列表)的方法要简单得多。

dwthyt8l

dwthyt8l2#

下面展示了如何将轴从一个图“移动”到另一个图。这是@JoeKington最后一个示例的预期功能,在较新的matplotlib版本中不再工作,因为轴不能同时存在于多个图中。
您首先需要从第一个图形中删除轴,然后将其附加到下一个图形,并给予一些位置。

import matplotlib.pyplot as plt

fig1, ax = plt.subplots()
ax.plot(range(10))
ax.remove()

fig2 = plt.figure()
ax.figure=fig2
fig2.axes.append(ax)
fig2.add_axes(ax)

dummy = fig2.add_subplot(111)
ax.set_position(dummy.get_position())
dummy.remove()
plt.close(fig1)

plt.show()
oo7oh9g9

oo7oh9g93#

对于线图,您可以处理Line2D对象本身:

fig1 = pylab.figure()
ax1 = fig1.add_subplot(111)
lines = ax1.plot(scipy.randn(10))

fig2 = pylab.figure()
ax2 = fig2.add_subplot(111)
ax2.add_line(lines[0])
5sxhfpxr

5sxhfpxr4#

TL;DR部分基于Joe很好的答案。

选项1:fig.add_subplot()

def fcn_return_plot():
    return plt.plot(np.random.random((10,)))
n = 4
fig = plt.figure(figsize=(n*3,2))
#fig, ax = plt.subplots(1, n,  sharey=True, figsize=(n*3,2)) # also works
for index in list(range(n)):
    fig.add_subplot(1, n, index + 1)
    fcn_return_plot()
    plt.title(f"plot: {index}", fontsize=20)

选项2:将ax[index]传递给返回ax[index].plot()的函数

def fcn_return_plot_input_ax(ax=None):
    if ax is None:
        ax = plt.gca()
    return ax.plot(np.random.random((10,)))
n = 4
fig, ax = plt.subplots(1, n,  sharey=True, figsize=(n*3,2))
for index in list(range(n)):
    fcn_return_plot_input_ax(ax[index])
    ax[index].set_title(f"plot: {index}", fontsize=20)

输出方面。x1c 0d1xx 1c 1d 1x
注意:Opt.1 plt.title()在opt.2中更改为ax[index].set_title()。在货车der Plas的书中找到更多Matplotlib Gotchas。

0yycz8jy

0yycz8jy5#

扩展我之前的答案,我们可以返回整个ax,而不仅仅是ax.plot()
如果dataframe有20种类型的100个测试(此处为id):

dfA = pd.DataFrame(np.random.random((100,3)), columns = ['y1', 'y2', 'y3'])
dfB = pd.DataFrame(np.repeat(list(range(20)),5), columns = ['id'])
dfC = dfA.join(dfB)

和情节函数(这是整个答案的关键):

def plot_feature_each_id(df, feature, id_range=[], ax=None, legend_bool=False):
    feature = df[feature]
    if not len(id_range): id_range=set(df['id'])
    legend_arr = []
    for k in id_range:
        pass
        mask = (df['id'] == k)
        ax.plot(feature[mask])
        legend_arr.append(f"id: {k}")
    if legend_bool: ax.legend(legend_arr)
    return ax

我们可以实现:

feature_arr = dfC.drop('id',1).columns
id_range= np.random.randint(len(set(dfC.id)), size=(10,))
n = len(feature_arr)
fig, ax = plt.subplots(1, n,  figsize=(n*6,4));
for i,k in enumerate(feature_arr):
    plot_feature_each_id(dfC, k, np.sort(id_range), ax[i], legend_bool=(i+1==n))
    ax[i].set_title(k, fontsize=20)
    ax[i].set_xlabel("test nr. (id)", fontsize=20)

相关问题