matplotlib 如何使用gridspec创建双轴

72qzrwbm  于 2023-04-07  发布在  其他
关注(0)|答案(1)|浏览(185)

我使用两个GridSpec对象生成以下图:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect

plt.rcParams.update({'figure.figsize' : (10,10)})
plt.rcParams.update({'font.size' : 18})

plt.rcParams.update({'mathtext.fontset': 'stix'})
plt.rcParams.update({'font.family': 'STIXGeneral'})

我的代码:

n_rows = 3
n_cols = 3

gs_right = plt.GridSpec(n_rows,n_cols, left=0.25,right=0.95,hspace=0)
gs_left = plt.GridSpec(n_rows,n_cols,left=0,right=0.6,wspace=0, hspace=0)

w, h = figaspect(2.4/3)
fig = plt.figure(figsize=(w,h),constrained_layout=False)

#Generate all the subplots
rightaxs = [fig.add_subplot(gs_right[i,1]) for i in range(3)]
leftaxs = [fig.add_subplot(gs_left[i,j]) for i in range(3) for j in range(2)]

#Colorbar params
width_cbar = 0.015
bottom_cbar = 0.01
leftax_xshift = 0.001
cmap = "Reds"

for ax in leftaxs:
    if ax.is_first_row():
        if ax.is_first_col():
            ax.set_title("Model 1")
        else:
            ax.set_title("Model 2")
    if ax.is_first_col():
        vmin = 0
        vmax = 10
        im = ax.hist2d(np.arange(-3,3,0.5),np.arange(-3,3,0.5), range=[[-15,15],[-15,15]], vmin=vmin,vmax=vmax,cmap=cmap)
    else:
        im = ax.hist2d(np.arange(-15,15,1),np.arange(-15,15,1), range=[[-15,15],[-15,15]], vmin=vmin,vmax=vmax,cmap=cmap)
        cax = fig.add_axes([ax.get_position().x1+leftax_xshift,ax.get_position().y0,width_cbar,ax.get_position().height])
        cax.tick_params(direction="in", length=6)
        fig.colorbar(im[3], cax=cax, extend='both', label=r'$N$')
    
for ax in rightaxs:
    if ax.is_first_row():
        ax.set_title("Model 3")
    vmin = 0
    vmax = 2
    im = ax.hist2d(np.arange(-13,13,1),np.arange(13,-13,-1),range=[[-15,15],[-15,15]],vmin=vmin,vmax=vmax,cmap=cmap)
    cax = fig.add_axes([ax.get_position().x1,ax.get_position().y0,width_cbar,ax.get_position().height])
    cax.tick_params(direction="in", length=6)
    fig.colorbar(im[3], cax=cax, extend='both', label=r'$\Delta N$')
    ax.set_aspect('equal')

#Axes ticks and labels----------------------------------------------------------------------
lb_ticks = np.linspace(-12,12,5)
tick_length = 6

for ax in leftaxs:
    ax.set_xticks(lb_ticks)
    ax.set_yticks(lb_ticks)
    ax.tick_params(direction="in", length=tick_length)
    second_axy = ax.twinx()
    second_axy.set_ylim([-15,15])
    second_axy.set_yticks(lb_ticks)
    second_axy.set_yticklabels([])
    second_axy.tick_params(direction='in', length=tick_length)
    second_axx = ax.twiny()
    second_axx.set_xlim([-15,15])
    second_axx.set_xticks(lb_ticks)
    second_axx.set_xticklabels([])
    second_axx.tick_params(direction='in', length=tick_length)
    
    if not ax.is_last_row():
        ax.set_xticklabels([])
    else:
        ax.set_xlabel(r"$\ell$ [°]")
        
    if not ax.is_first_col():
        ax.set_yticklabels([])
    else:
        ax.set_ylabel(r"$b$ [°]")    

for ax in rightaxs:
    ax.set_xticks(lb_ticks)
    ax.set_yticks(lb_ticks)
    ax.tick_params(direction='in',length=tick_length)
    second_axy = ax.twinx()
    second_axy.set_ylim([-15,15])
    second_axy.set_yticks(lb_ticks)
    second_axy.set_yticklabels([])
    second_axy.tick_params(direction='in', length=tick_length)
    #second_axx = ax.twiny()
    #second_axx.set_xlim([-15,15])
    #second_axx.set_xticks(lb_ticks)
    #second_axx.set_xticklabels([])
    #second_axx.tick_params(direction='in', length=tick_length)
    ax.set_yticklabels([])
    
    if not ax.is_last_row():
        ax.set_xticklabels([])
    else:
        ax.set_xlabel(r"$\ell$ [°]")

plt.show()

我在结尾处遇到了以下几行(已注解)的问题:

second_axx = ax.twiny()
second_axx.set_xlim([-15,15])
second_axx.set_xticks(lb_ticks)
second_axx.set_xticklabels([])
second_axx.tick_params(direction='in', length=tick_length)

就像在左边块的子图中一样,我希望上面的行在右边子图的顶部产生刻度。然而,取消注解这些行反而会导致以下错误,我不明白:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
    339                 pass
    340             else:
--> 341                 return printer(obj)
    342             # Finally look for special method names
    343             method = get_real_method(obj, self.print_method)

~\anaconda3\lib\site-packages\IPython\core\pylabtools.py in <lambda>(fig)
    246 
    247     if 'png' in formats:
--> 248         png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))
    249     if 'retina' in formats or 'png2x' in formats:
    250         png_formatter.for_type(Figure, lambda fig: retina_figure(fig, **kwargs))

~\anaconda3\lib\site-packages\IPython\core\pylabtools.py in print_figure(fig, fmt, bbox_inches, **kwargs)
    130         FigureCanvasBase(fig)
    131 
--> 132     fig.canvas.print_figure(bytes_io, **kw)
    133     data = bytes_io.getvalue()
    134     if fmt == 'svg':

~\anaconda3\lib\site-packages\matplotlib\backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
   2191                            else suppress())
   2192                     with ctx:
-> 2193                         self.figure.draw(renderer)
   2194 
   2195                     bbox_inches = self.figure.get_tightbbox(

~\anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     39                 renderer.start_filter()
     40 
---> 41             return draw(artist, renderer, *args, **kwargs)
     42         finally:
     43             if artist.get_agg_filter() is not None:

~\anaconda3\lib\site-packages\matplotlib\figure.py in draw(self, renderer)
   1838                 ax.apply_aspect(pos)
   1839             else:
-> 1840                 ax.apply_aspect()
   1841 
   1842             for child in ax.get_children():

~\anaconda3\lib\site-packages\matplotlib\axes\_base.py in apply_aspect(self, position)
   1659         # Not sure whether we need this check:
   1660         if shared_x and shared_y:
-> 1661             raise RuntimeError("adjustable='datalim' is not allowed when both "
   1662                                "axes are shared")
   1663 

RuntimeError: adjustable='datalim' is not allowed when both axes are shared

有什么建议吗?我也很感激任何关于如何提高代码健壮性的评论。我觉得它有点硬编码,因为手动设置:

  • plt.GridSpec()中的参数leftright
  • 纵横比为figaspect
  • 颜色条轴的参数。
jecbmhm3

jecbmhm31#

我意识到我可以完全避免生成双轴来在顶部和右侧设置刻度,只需执行以下操作:

plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True

我的代码中有很大一部分也通过建立设置得到了简化

plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.major.size'] = 6
plt.rcParams['ytick.major.size'] = 6

相关问题