matplotlib 如何将堆叠x轴标签添加到堆叠条形图

yruzcnhs  于 2023-10-24  发布在  其他
关注(0)|答案(1)|浏览(110)

给定以下代码:

import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby
import pandas as pd

percent_EVs = [0, 8, 21, 26, 37, 39, 41, 75, 95, 97]
percent_DVs = [100, 92, 79, 74, 63, 61, 59, 25, 5, 3]
num_buses = [1423, 1489, 1613, 1606, 1710, 1684, 1694, 2153, 2202, 2195]
veh_range = ['DV only', 60, 120, 150, 60, 120, 150, 60, 120, 150]
deployment = ['DV only', 'Low', 'Medium', 'High', 'Low', 'Medium', 'High', 'Low', 'Medium', 'High']

# One row per scenario; the vehicle range is moved into the index so the
# bar plot uses it as the x-axis category.
df = pd.DataFrame(
    {
        'Percent EVs': percent_EVs,
        'Percent DVs': percent_DVs,
        '# Buses': num_buses,
        'Range (mi)': veh_range,
        'Deployment target': deployment,
    }
).set_index('Range (mi)')

def add_line(ax, xpos, ypos):
    """Draw a short black vertical segment at axes-fraction *xpos*.

    The segment runs from ``ypos + .1`` down to ``ypos`` in axes
    coordinates. Clipping is turned off so the segment stays visible
    when it lies outside the axes box.
    """
    tick = plt.Line2D(
        [xpos, xpos],
        [ypos + .1, ypos],
        transform=ax.transAxes,
        color='black',
    )
    tick.set_clip_on(False)  # typically drawn below the axes frame
    ax.add_line(tick)

def label_len(my_index, level):
    """Run-length encode the labels found at *level* of *my_index*.

    Returns a list of ``(label, count)`` pairs, one per run of consecutive
    equal labels, in order of appearance. Non-adjacent repeats of a label
    produce separate runs.
    """
    level_labels = my_index.get_level_values(level)
    return [(label, len(list(run))) for label, run in groupby(level_labels)]

def label_group_bar_table(ax, df):
    """Write the index of *df* below *ax* as rows of grouped labels.

    One text row is drawn per index level. The last level comes first,
    directly under the axes, and each earlier level sits 0.1 axes units
    lower. Within a row, consecutive equal labels (see ``label_len``) are
    merged into one centred label, and vertical separators are drawn at
    both edges of every run. Horizontal positions are evenly spaced
    fractions of the axes width, one slot per index entry.

    An empty index draws nothing. Previously it raised ZeroDivisionError
    when computing the slot width.
    """
    index = df.index
    n_rows = index.size
    if n_rows == 0:
        return
    ypos = -.1
    scale = 1. / n_rows  # axes-fraction width of one index entry
    for level in reversed(range(index.nlevels)):
        pos = 0
        for label, rpos in label_len(index, level):
            lxpos = (pos + .5 * rpos) * scale  # centre of this run
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos * scale, ypos)  # left edge of this run
            pos += rpos
        add_line(ax, pos * scale, ypos)  # closing right edge of the row
        ypos -= .1

fig = plt.figure()
ax = fig.add_subplot(111)
# Stacked bar chart of df, drawn onto the axes created above.
df.plot.bar(stacked=True, rot=0, alpha=0.5, legend=False, ax=ax)

# Blank the default tick labels and x label; grouped labels are drawn by hand.
labels = [''] * len(ax.get_xticklabels())
ax.set_xticklabels(labels)
ax.set_xlabel('')
label_group_bar_table(ax, df)

# Leave room below/left of the axes proportional to the number of index levels.
fig.subplots_adjust(
    bottom=.2 * df.index.nlevels,
    left=0.1 * df.index.nlevels,
)
plt.show()

当前结果：

结果图与所需的输出相差甚远。如何至少修改 x 轴标签，使其接近所需的效果？

所需的 x 轴效果：

参考这里，我想绘制如下图表：

hm2xizp9

hm2xizp91#

以下是您正在寻找的解决方案:

import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.font_manager as font_manager
import numpy as np

def add_line(ax, xpos, ypos):
    """Draw an unclipped black vertical segment at axes-fraction *xpos*.

    The segment spans from ``ypos + .1`` to ``ypos * 1.5`` in axes
    coordinates. For a negative *ypos* it therefore extends below *ypos*.
    """
    y_top, y_bottom = ypos + .1, ypos * 1.5
    separator = plt.Line2D([xpos] * 2, [y_top, y_bottom],
                           transform=ax.transAxes, color='black')
    separator.set_clip_on(False)  # keep it visible outside the axes box
    ax.add_line(separator)

percent_EVs = [0, 8, 21, 26, 37, 39, 41, 75, 95, 97]
percent_DVs = [100, 92, 79, 74, 63, 61, 59, 25, 5, 3]
num_buses = [1423, 1489, 1613, 1606, 1710, 1684, 1694, 2153, 2202, 2195]
veh_range = ["DV", 60, 120, 150, 60, 120, 150, 60, 120, 150]
# A deployment label is stored only at the middle row of each three-row group
# (indices 2, 5 and 8); the other rows carry an empty string.
deployment = ["", "", "Low", "", "", "Medium", "", "", "High", ""]

df = pd.DataFrame.from_dict(
    {
        "Percent EVs": percent_EVs,
        "Percent DVs": percent_DVs,
        "# Buses": num_buses,
        "Range (mi)": veh_range,
        "Deployment target": deployment,
    }
)

fig, ax1 = plt.subplots(figsize=(8, 5))

colors = ["#A9D18E", "#9cc2e5"]  # [EV share, DV share]

x = df.index.values  # one bar slot per DataFrame row

fontsize = 20
font = {'family': 'Times New Roman', 'size': fontsize}
font_leg = font_manager.FontProperties(family='Times New Roman',
                                       size=fontsize)

# Stacked bars: EV share at the bottom, DV share stacked on top of it.
ax1.bar(
    x,
    df["Percent EVs"],
    color=colors[0],
    label="% EVs",
)
ax1.bar(
    x,
    df["Percent DVs"],
    color=colors[1],
    bottom=df["Percent EVs"],
    label="% DVs",
)
ax1.set_xlabel("Range (mi) / Deployment Target", fontdict=font)
ax1.set_ylabel("% Share", fontdict=font, color="black")
# Style ax1's tick labels explicitly rather than through pyplot's implicit
# "current axes" (plt.yticks), so the target axes is unambiguous.
plt.setp(ax1.get_yticklabels(), fontsize=fontsize, fontname="Times New Roman")

# Print each segment's value at its centre; values of 5 or less get no label.
for container in ax1.containers:
    labels = [v if v > 5 else "" for v in container.datavalues]
    ax1.bar_label(container, labels=labels, label_type="center", font=font, color="white")

# Two-line tick labels: range on the first line, deployment target below.
# Zipping the two columns avoids the per-row Series built by iterrows().
custom_ticks = [
    f"{veh_rng}\n{target}"
    for veh_rng, target in zip(df["Range (mi)"], df["Deployment target"])
]
ax1.set_xticks(x)
ax1.set_xticklabels(custom_ticks, fontdict=font)

ax2 = ax1.twinx()
ax2.set_ylabel("# Buses", color="black", fontdict=font)

ax2.plot(
    x,
    df["# Buses"],
    linestyle="-",
    color="k",
    label="# Buses",
)

# Configure ax2's ticks explicitly. plt.yticks would only hit ax2 because
# twinx() happened to make it pyplot's current axes.
ax2.set_yticks(np.arange(0, 3100, 500))
plt.setp(ax2.get_yticklabels(), fontsize=fontsize, fontname="Times New Roman")
ax1.set_ylim(0, 102)

# Vertical separators under the x axis: one at each axes edge, plus one
# between neighbouring groups (the DV-only bar, then three bars per target).
# Group boundaries sit midway between bars in data coordinates. They are
# converted to axes fractions from the final x-limits, instead of using
# hard-coded fractions tuned to matplotlib's default x margins.
group_sizes = [1, 3, 3, 3]
xmin, xmax = ax1.get_xlim()
boundaries = np.cumsum(group_sizes)[:-1] - 0.5
separator_fracs = [0.0, *((boundaries - xmin) / (xmax - xmin)), 1.0]
for frac in separator_fracs:
    add_line(ax1, frac, -0.1)

# Merge the handles of both axes into a single legend drawn on ax2.
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(
    lines + lines2,
    labels + labels2,
    loc='upper left',
    prop=font_leg,  # sets the font size too; a separate fontsize= would be ignored
    facecolor='white',
    framealpha=1,
    fancybox=True,
    shadow=False,
    ncol=2,
    numpoints=1,
    columnspacing=0.4,
    handletextpad=0.2,
)

plt.tight_layout()
plt.show()

相关问题