matplotlib:如何用曲线表示坐标轴上被省略(未显示)的部分

deikduxw  于 2023-10-24  发布在  其他
关注(0)|答案(1)|浏览(99)

我想尝试复制这张图表的格式(格式,不一定是内容)。

我找到了一个讲解如何实现这种图表的教程(here),其中给出了下面的代码(数据集见 here)。
但是它没有坐标轴左侧的那段曲线,也就是用来表示坐标轴有一部分被跳过的曲线。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Dumbbell chart of GDP (in trillions of USD) in 2000 vs 2020 for a set of
# countries, styled with a red header line/tag, a title and a source note.

# Raise the figure resolution so the plot is easier to inspect while building it.
plt.rcParams['figure.dpi'] = 100

# Load the data.
# NOTE(review): this snippet never defines `countries` or the `gdp_trillions`
# column that are used below; both have to exist before the data is filtered.
gdp = pd.read_csv('gdp_1960_2020.csv')

# Set up the plot size.
fig, ax = plt.subplots(figsize=(7, 4))

# Create the grid.
# zorder selects the drawing layer: the grid is on layer 1 and the data on
# layers 2 and 3, so the grid stays behind the data.
ax.grid(which="major", axis='both', color='#758D99', alpha=0.6, zorder=1)

# Hide spines. Several can be hidden at once by indexing with a list.
ax.spines[['top', 'right', 'bottom']].set_visible(False)

# Set up the data: one row per country, one column per year (2000 and 2020),
# ordered by the 2020 value.
gdp_dumbbell = (
    gdp[gdp['country'].isin(countries) & gdp['year'].isin([2000, 2020])]
    [['year', 'gdp_trillions', 'country']]
    .pivot(index='country', columns='year', values='gdp_trillions')
    .sort_values(by=2020)
)

# Plot the data.
# Horizontal connector lines first; '_nolegend_' keeps them out of the legend.
ax.hlines(
    y=gdp_dumbbell.index,
    xmin=gdp_dumbbell[2000],
    xmax=gdp_dumbbell[2020],
    color='#758D99',
    zorder=2,
    linewidth=2,
    label='_nolegend_',
    alpha=.8,
)

# Then the bubbles, one layer above the connector lines.
ax.scatter(gdp_dumbbell[2000], gdp_dumbbell.index, label='2000', s=60, color='#DB444B', zorder=3)
ax.scatter(gdp_dumbbell[2020], gdp_dumbbell.index, label='2020', s=60, color='#006BA2', zorder=3)

# Set the x-axis limits.
ax.set_xlim(0, 25.05)

# Reformat the x-axis tick labels.
ax.xaxis.set_tick_params(
    labeltop=True,      # Put x-axis labels on top
    labelbottom=False,  # No x-axis labels on the bottom
    bottom=False,       # No ticks on the bottom
    labelsize=11,       # Tick label size
    pad=-1,             # Lower the tick labels a bit
)

# Reformat the y-axis tick labels.
# Pin the current tick positions before relabelling: set_yticklabels is only
# meant to be used once the tick locations are fixed.
ax.set_yticks(ax.get_yticks())
ax.set_yticklabels(
    gdp_dumbbell.index,  # Set the labels again
    ha='left',           # Left-align the label text
)
ax.yaxis.set_tick_params(
    pad=100,       # Pad the tick labels so they do not run over the y-axis
    labelsize=11,  # Label size
    left=False,    # No tick marks on the left
)

# Set the legend.
ax.legend(['2000', '2020'], loc=(-.29, 1.09), ncol=2, frameon=False, handletextpad=-.1, handleheight=1)

# Add the header line and tag. transform=fig.transFigure places them in
# figure coordinates, and clip_on=False lets them draw outside the axes.
ax.plot(
    [-0.08, .9],                # Horizontal extent of the line
    [1.17, 1.17],               # Height of the line
    transform=fig.transFigure,
    clip_on=False,
    color='#E3120B',
    linewidth=.6,
)

ax.add_patch(
    plt.Rectangle(
        (-0.08, 1.17),  # Anchor corner of the rectangle
        0.05,           # Width of the rectangle
        -0.025,         # Height of the rectangle; negative so it extends downwards
        facecolor='#E3120B',
        transform=fig.transFigure,
        clip_on=False,
        linewidth=0,
    )
)

# Add the title and subtitle.
ax.text(x=-0.08, y=1.09, s="Great expectations", transform=fig.transFigure, ha='left', fontsize=13, weight='bold', alpha=.8)

ax.text(x=-0.08, y=1.04, s="Top 9 countries by GDP, in trillions of USD", transform=fig.transFigure, ha='left', fontsize=11, alpha=.8)

# Add the source text.
ax.text(x=-0.08, y=0.04, s="""Source: "GDP of all countries(1960-2020)" via Kaggle.com""", transform=fig.transFigure, ha='left', fontsize=9, alpha=.7)

plt.show()

这会产生一个类似的图表。

但它没有那些弯曲的线条。我该怎样才能画出这些曲线?
图表与图像的不同之处不止于此(它使用不同的数据集,其格式在几个方面不同)。

eni9jsuy

eni9jsuy1#

您可以手动调整它:

# Only the lines that were changed or added relative to the original script
# are shown here.

# Start the x-axis at -2 so there is room left of zero for the zig-zag marker.
ax.set_xlim(-2, 25.05)

# Keep only the vertical grid lines; the horizontal guides are drawn by hand below.
ax.grid(which='major', axis='x', color='#758D99', alpha=0.6, zorder=1)

# Hide all four spines, including the left one.
ax.spines[['left', 'top', 'right', 'bottom']].set_visible(False)

# Draw one guide per y tick: flat from x=-2, a zig-zag (up 0.3, down 0.3)
# between x=-1.6 and x=-1, then flat up to the largest x tick once the last
# tick is excluded.
guide_style = dict(color='#758D99', alpha=0.3)
for tick_y in ax.get_yticks():
    guide_x = [-2, -1.6, -1.4, -1.2, -1, ax.get_xticks()[:-1].max()]
    guide_y = [tick_y, tick_y, tick_y + .3, tick_y - .3, tick_y, tick_y]
    ax.plot(guide_x, guide_y, **guide_style)

plt.show()

顺便说一下,你的代码无法完整复现,缺少了下面这两行:

# Express GDP in trillions.
gdp['gdp_trillions'] = gdp['gdp'] / 10**12

# Country names of the nine rows with the largest 2020 value, in ascending order.
gdp_2020_sorted = gdp[gdp['year'] == 2020].sort_values(by='gdp_trillions')
countries = gdp_2020_sorted['country'].values[-9:]

相关问题