matplotlib 向海运FacetGrid添加自定义误差条以确认色调和分类x顺序

yh2wf1be  于 2023-03-09  发布在  其他
关注(0)|答案(1)|浏览(131)

我有一个具有预先计算的平均值和标准差的数据集。这些值取决于三个不同的分类值。我希望创建两个条形图,以便在它们之间拆分第一个分类变量。其他两个分类值应在x轴上使用不同的颜色分隔。
seaborn方面,我想创建基于分类xseaborn.catplot条形图,并自定义order以及huecol参数,同时能够添加我自己的自定义标准差。
下面的代码非常直接地给出了条形图的平均值:

import seaborn as sns
import matplotlib.pyplot as plt

tips = sns.load_dataset("tips")

tip_sumstats = (tips.groupby(["day", "sex", "smoker"])
                     .total_bill
                     .agg(["mean", 'sem'])
                     .reset_index())

sns.catplot(
    data=tip_sumstats,
    x="day",
    order=["Sun", "Thur", "Fri", "Sat"],
    y="mean",
    hue="smoker",
    col="sex",
    kind="bar",
    height=4,
)

This answer解决了不涉及hueorder时的问题。

def errplot(x, y, yerr, **kwargs):
    ax = plt.gca()
    data = kwargs.pop("data")
    data.plot(x=x, y=y, yerr=yerr, kind="bar", ax=ax, **kwargs)

g = sns.FacetGrid(tip_sumstats, col="sex", hue="smoker", height=4)
g.map_dataframe(errplot, "day", "mean", "sem")


中的结果
我不知道如何修改这个版本,使它遵守由某个order参数定义的x轴上的分类顺序,而且,我不知道如何向它添加一个dodge=True,使不同颜色的条彼此相邻。
This question试图解决类似的问题,然而,方法非常技术化,一点也不直接,对我来说,没有直接的解决方案存在似乎很奇怪。

xqnpmsa8

xqnpmsa81#

Seaborn不支持这种开箱即用的方式,可能是因为误差条的许多选项都很复杂,难以适应参数传递的方式。
对于您的具体情况,您可以按如下方式计算位置:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

tips = sns.load_dataset("tips")
tip_sumstats = (tips.groupby(["day", "sex", "smoker"])
                .total_bill
                .agg(["mean", 'sem'])
                .reset_index())

def errplot(x, y, data, order, hue, yerr, palette='deep', color=None):
    xs = np.arange(len(order))
    hues = data[hue].unique()
    dodge_width = 0.8
    dodge_vals = np.linspace(-dodge_width / 2, dodge_width / 2, len(hues)*2+1)[1::2]
    colors = sns.color_palette(palette, len(hues))
    for hue_val, dodge_val, color in zip(hues, dodge_vals, colors):
        ys = [data[(data[x] == xi) & (data[hue] == hue_val)][y].to_numpy()[0] for xi in order]
        yerrs = [data[(data[x] == xi) & (data[hue] == hue_val)][yerr].to_numpy()[0] for xi in order]
        plt.bar(x=xs + dodge_val, height=ys, yerr=yerrs, width=dodge_width / len(hues), color=color, label=hue_val)
    plt.xticks(xs, order)

g = sns.FacetGrid(tip_sumstats, col="sex", height=4)
g.map_dataframe(errplot, "day", "mean", hue="smoker", yerr="sem", order=["Sun", "Thur", "Fri", "Sat"])
g.fig.legend(*g.axes.flat[-1].get_legend_handles_labels(), title='smoker')
plt.show()

相关问题