matplotlib 共享副轴

7qhs6swi  于 2023-10-24  发布在  其他
关注(0)|答案(1)|浏览(108)

如何在 matplotlib 的子图中设置共享的次坐标轴(secondary y axis)?
下面是显示问题的最小代码:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

def countour_every(ax, every, x_data, y_data,
                   color='black', linestyle='-', marker='o', **kwargs):
    """Draw ``y_data`` against ``x_data`` on *ax*, with a marker every *every* points.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    every : int
        Forwarded as ``markevery``: a marker is drawn on every ``every``-th
        data point, starting with the first one.
    x_data, y_data : array-like
        Coordinates of the line.
    color, linestyle, marker :
        Line styling forwarded to ``ax.plot``.
    **kwargs :
        Accepted and ignored. Callers unpack whole series dicts into this
        function (keys such as ``y_lim`` or ``label``), and those are not
        Line2D properties. Forwarding them to ``ax.plot`` would raise.

    Returns
    -------
    Line2D
        The single line created by ``ax.plot``.
    """
    # Pass the styling as keywords. Previously only ``linestyle`` was used
    # (as a positional fmt string), and color/marker/every were dropped.
    line, = ax.plot(x_data, y_data, color=color, linestyle=linestyle,
                    marker=marker, markevery=every)
    return line

def prettify_axes(ax, data):
    """Apply title, axis limits, legend and compact font styling to *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes styled in place.
    data : dict
        Series description. When present, the optional keys ``'title'``,
        ``'y_lim'`` and ``'x_lim'`` are applied to *ax*; other keys are
        ignored.
    """
    if 'title' in data:
        ax.set_title(data['title'])

    if 'y_lim' in data:
        ax.set_ylim(data['y_lim'])

    if 'x_lim' in data:
        ax.set_xlim(data['x_lim'])

    # Only draw a legend when some artist actually carries a label.
    # get_legend_handles_labels() skips artists whose label starts with '_'
    # (the default for unlabelled lines). Calling ax.legend() with nothing
    # to show only produces a "No artists with labels found" warning.
    _handles, labels = ax.get_legend_handles_labels()
    if labels:
        ax.legend(loc='upper right', prop={'size': 6})

    ax.title.set_fontsize(7)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(labelsize=6, direction='in')
        axis.label.set_size(7)

def prettify_second_axes(ax):
    """Style the y-axis of a secondary (twin) Axes.

    Sets size-7 red tick labels and a size-7 axis label on ``ax.yaxis``.
    """
    y_axis = ax.yaxis
    y_axis.set_tick_params(labelsize=7, labelcolor='red')
    y_axis.label.set_size(7)

def compare_plot(ax, data):
    """Plot two series on *ax* and their difference on a twin y-axis.

    Parameters
    ----------
    ax : matplotlib Axes
        Receives the two series on its (primary) y-axis.
    data : sequence of two dicts
        Each dict is unpacked into ``countour_every`` and must provide
        ``'x_data'`` and ``'y_data'``. An optional ``'label'`` is attached to
        the plotted line. ``data[0]`` also supplies the title/limits for
        ``prettify_axes``.

    Returns
    -------
    matplotlib Axes
        The twin Axes created by ``ax.twinx()``. Returning it lets callers
        share its y-axis across subplots and hide redundant tick labels.
    """
    first, second = data[0], data[1]
    for series in (first, second):
        line = countour_every(ax, 10, **series)
        if 'label' in series:
            line.set_label(series['label'])

    ax2 = ax.twinx()
    # Draw the difference on the twin axes, not on ``ax``. Only then does
    # ax2 autoscale to the error values independently of ax's y_lim.
    ax2.plot(first['x_data'], first['y_data'] - second['y_data'], '-',
             color='red', alpha=.2, zorder=1)

    prettify_axes(ax, first)
    prettify_second_axes(ax2)
    return ax2

# Four demo series sharing one x range. np.random.random already returns
# values in [0, 1), so no abs() is needed to get non-negative data.
d0 = {'x_data': np.arange(0, 10), 'y_data': np.random.random(10),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-', 'label': 'd0'}
d1 = {'x_data': np.arange(0, 10), 'y_data': -np.random.random(10),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '--', 'label': 'd1'}
d2 = {'x_data': np.arange(0, 10), 'y_data': np.random.random(10),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}
d3 = {'x_data': np.arange(0, 10), 'y_data': -np.ones(10),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}

fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(6, 6))

compare_plot(axes[0][0], [d0, d1])
compare_plot(axes[0][1], [d0, d2])
compare_plot(axes[1][0], [d1, d0])
compare_plot(axes[1][1], [d3, d2])

fig.suptitle('A comparison chart')
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')
# rect keeps the bottom 3% and top 5% of the figure outside the subplot
# area, leaving room for the 'Position' text and the suptitle.
# fig.tight_layout replaces the deprecated fig.set_tight_layout({...}).
fig.tight_layout(rect=[0, 0.03, 1, 0.95])

fig.savefig('demo.png', dpi=300)

生成以下图像

我们可以看到,X 轴和 Y 轴已正确共享,但次 y 轴(由 twinx 创建的双生轴)在每个子图中都重复出现。
此外,次轴没有根据数据正确缩放(它的范围应独立于主 y 轴的限制)。

qlfbtfca

qlfbtfca1#

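您需要手动共享这些双生轴(share the twin axes manually),并删除多余的刻度标签(ticklabels)。另外,差值曲线必须绘制在 `ax2` 上(问题代码中误写成了 `ax.plot`),次轴才会按数据自动缩放;为此 `compare_plot` 还需要返回 `ax2`: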

def compare_plot(ax, data):
    """Plot two series on *ax*, their difference on a twin y-axis, and return the twin.

    Parameters
    ----------
    ax : matplotlib Axes
        Receives the two series on its (primary) y-axis.
    data : sequence of two dicts
        Each dict is unpacked into ``countour_every`` and must provide
        ``'x_data'`` and ``'y_data'``. An optional ``'label'`` is attached to
        the plotted line. ``data[0]`` also supplies the title/limits for
        ``prettify_axes``.

    Returns
    -------
    matplotlib Axes
        The twin Axes created by ``ax.twinx()``, so the caller can share its
        y-axis with the other subplots' twins.
    """
    first, second = data[0], data[1]
    for series in (first, second):
        line = countour_every(ax, 10, **series)
        if 'label' in series:
            line.set_label(series['label'])

    ax2 = ax.twinx()
    # The difference must live on ax2 (not ax); otherwise ax2 holds no data
    # and autoscaling the shared twin axis has nothing to fit.
    ax2.plot(first['x_data'], first['y_data'] - second['y_data'], '-',
             color='red', alpha=.2, zorder=1)

    prettify_axes(ax, first)
    prettify_second_axes(ax2)
    return ax2

# Keep the twin (secondary) axes returned by compare_plot so their y-axes
# can be linked across all four subplots.
sax1 = compare_plot(axes[0][0], [d0, d1])
sax2 = compare_plot(axes[0][1], [d0, d2])
sax3 = compare_plot(axes[1][0], [d1, d0])
sax4 = compare_plot(axes[1][1], [d3, d2])

# Link every twin y-axis to sax1's. Axes.sharey (Matplotlib >= 3.3)
# replaces the get_shared_y_axes().join(...) idiom, which recent Matplotlib
# releases no longer support: get_shared_y_axes() now returns an immutable
# view without join().
for sax in (sax2, sax3, sax4):
    sax.sharey(sax1)
# Rescale the shared twin y-axis to the union of all four twins' data.
sax1.autoscale()
# Only the right-hand column needs secondary tick labels; hide them on the
# left-column twins (top-left and bottom-left).
for sax in (sax1, sax3):
    sax.yaxis.set_tick_params(labelright=False)

相关问题