设置pandas中单元格组的格式

k7fdbhmy  于 2023-04-10  发布在  其他
关注(0)|答案(1)|浏览(130)

我的Pandas Dataframe 如下所示

import pandas as pd

# Sample data: one row per (A, B, C, model, metric) holding the point
# estimate `val` and the bounds `val_lo` / `val_hi`.
# Note that the "c2" group has no "rnn" rows.
inp_df = pd.DataFrame(
    data=[
        # --- C == "c1" ---
        ["a1", "b1", "c1", "gbt", "auc", 82.5, 80.1, 83.6],
        ["a1", "b1", "c1", "gbt", "pr@5%", 0.3, 0.2, 0.4],
        ["a1", "b1", "c1", "gbt", "re@5%", 60.2, 58.1, 61.3],
        ["a1", "b1", "c1", "rnn", "auc", 84.1, 83.8, 84.5],
        ["a1", "b1", "c1", "rnn", "pr@5%", 0.5, 0.4, 0.6],
        ["a1", "b1", "c1", "rnn", "re@5%", 61.5, 61.4, 61.7],
        ["a1", "b1", "c1", "llm", "auc", 84.3, 84.1, 84.6],
        ["a1", "b1", "c1", "llm", "pr@5%", 0.8, 0.7, 0.9],
        ["a1", "b1", "c1", "llm", "re@5%", 61.2, 61.1, 61.3],
        # --- C == "c2" ---
        ["a1", "b1", "c2", "gbt", "auc", 82.5, 80.1, 83.6],
        ["a1", "b1", "c2", "gbt", "pr@5%", 0.3, 0.2, 0.4],
        ["a1", "b1", "c2", "gbt", "re@5%", 60.2, 58.1, 61.3],
        ["a1", "b1", "c2", "llm", "auc", 84.3, 84.1, 84.6],
        ["a1", "b1", "c2", "llm", "pr@5%", 0.8, 0.7, 0.9],
        ["a1", "b1", "c2", "llm", "re@5%", 61.2, 61.1, 61.3],
    ],
    columns=["A", "B", "C", "model", "metric", "val", "val_lo", "val_hi"],
)

我想把它显示成如下所示的样子(原帖此处的示例图片未能保留):
备注:
1.对于每个metric(例如auc),用粗体标出val最高的模型
2.突出显示(在(A,B,C)内)(val_lo,val_hi)相互重叠的所有模型的单元格;(val_lo,val_hi)是置信区间。
3.在每组模型后画一条线
用下面这个视图来看可能更容易理解:

import itertools as it

# Turn `model` into an ordered categorical (gbt < rnn < llm) so that
# order-sensitive operations use this order instead of alphabetical order.
inp_df["model"] = pd.Categorical(
    inp_df["model"], ["gbt", "rnn", "llm"], ordered=True
)

# Desired column layout: for every metric, the value followed by its bounds.
cols = list(
    it.product(
        ["auc", "pr@5%", "re@5%"],
        ["val", "val_lo", "val_hi"],
    )
)

# Wide view: the first four columns (A, B, C, model) become the row index and
# each metric gets its own (val, val_lo, val_hi) column triple.  The result is
# left as a bare expression so a notebook displays it.
(
    inp_df.pivot(
        index=inp_df.columns[:4],
        columns="metric",
        values=inp_df.columns[-3:],
    )
    # pivot puts the value name on the outer column level; move metric there.
    .swaplevel(0, 1, axis=1)
    # Put the columns in the exact order listed in `cols`.
    .reindex(pd.MultiIndex.from_tuples(cols), axis=1)
)

使用df.apply(func)和合适的groupby可以很容易地识别每个指标的最大值,并查看哪些行与该行重叠。但我不知道如何将其格式化为如上所示!

oxosxuxt

oxosxuxt1#

要跟踪并理解脚本的每一步,只需依次 print(r)、print(r2)、print(r3)

建议稿

import pandas as pd

df = pd.DataFrame(
    [
        ["a1", "b1", "c1", "gbt",   "auc",  82.5,  80.1,  83.6],
        ["a1", "b1", "c1", "gbt", "pr@5%",   0.3,   0.2,   0.4],
        ["a1", "b1", "c1", "gbt", "re@5%",  60.2,  58.1,  61.3],
        ["a1", "b1", "c1", "rnn",   "auc",  84.1,  83.8,  84.5],
        ["a1", "b1", "c1", "rnn", "pr@5%",   0.5,   0.4,   0.6],
        ["a1", "b1", "c1", "rnn", "re@5%",  61.5,  61.4,  61.7],
        ["a1", "b1", "c1", "llm",   "auc",  84.3,  84.1,  84.6],
        ["a1", "b1", "c1", "llm", "pr@5%",   0.8,   0.7,   0.9],
        ["a1", "b1", "c1", "llm", "re@5%",  61.2,  61.1,  61.3],
        ["a1", "b1", "c2", "gbt",   "auc",  82.5,  80.1,  83.6],
        ["a1", "b1", "c2", "gbt", "pr@5%",   0.3,   0.2,   0.4],
        ["a1", "b1", "c2", "gbt", "re@5%",  60.2,  58.1,  61.3],
        ["a1", "b1", "c2", "llm",   "auc",  84.3,  84.1,  84.6],
        ["a1", "b1", "c2", "llm", "pr@5%",   0.8,   0.7,   0.9],
        ["a1", "b1", "c2", "llm", "re@5%",  61.2,  61.1,  61.3],
    ],
    columns=["A", "B", "C", "model", "metric", "val", "val_lo", "val_hi"],
)

# Step 1 (r): long format -- one row per (A, B, C, model, metric, variable),
# where `variable` is one of "val" / "val_lo" / "val_hi" and `value` is the
# corresponding number.
r = pd.melt(
    df,
    id_vars=["A", "B", "C", "model", "metric"],
    value_vars=["val", "val_lo", "val_hi"],
)


def func(g):
    """Return a copy of ``g`` with a ``val`` column "val(val_lo-val_hi)".

    Parameters
    ----------
    g : pandas.DataFrame
        Rows of the melted frame for a single (A, B, C, model, metric)
        combination.  Must contain the columns ``variable`` and ``value``
        with one row for each of "val", "val_lo" and "val_hi".

    Returns
    -------
    pandas.DataFrame
        A copy of ``g`` in which every row carries the same formatted string
        in the new ``val`` column.  ``g`` itself is left untouched.
    """
    # Look the three numbers up by their `variable` label so the result does
    # not depend on row order or on where the `value` column sits.
    by_name = g.set_index("variable")["value"]
    out = g.copy()
    out["val"] = "%s(%s-%s)" % (
        by_name["val"],
        by_name["val_lo"],
        by_name["val_hi"],
    )
    return out


# Step 2 (r2): one row per (A, B, C, model, metric) with the formatted cell
# text in `val`.  The groups are iterated explicitly so that `func` always
# receives the full sub-frame.  The three per-variable rows of a group become
# identical once `variable` / `value` are dropped, hence drop_duplicates().
r2 = (
    pd.concat(
        [
            func(g)
            for _, g in r.groupby(["A", "B", "C", "model", "metric"], sort=False)
        ]
    )
    .drop(columns=["variable", "value"])
    .drop_duplicates()
    .reset_index(drop=True)
)

# Step 3 (r3): spread the metrics into columns.  A single pivot keyed on
# (A, B, C, model) yields exactly one row per model, so no filling across
# rows is needed and values cannot leak from one model into another.
r3 = (
    r2.pivot(index=["A", "B", "C", "model"], columns="metric", values="val")
    .reset_index()
)

r3.columns.name = ''

# Expected rows (pivot sorts the row keys, so models come out alphabetically):
#   a1 b1 c1 gbt  82.5(80.1-83.6)  0.3(0.2-0.4)  60.2(58.1-61.3)
#   a1 b1 c1 llm  84.3(84.1-84.6)  0.8(0.7-0.9)  61.2(61.1-61.3)
#   a1 b1 c1 rnn  84.1(83.8-84.5)  0.5(0.4-0.6)  61.5(61.4-61.7)
#   a1 b1 c2 gbt  82.5(80.1-83.6)  0.3(0.2-0.4)  60.2(58.1-61.3)
#   a1 b1 c2 llm  84.3(84.1-84.6)  0.8(0.7-0.9)  61.2(61.1-61.3)
print(r3)

结果

A   B model   C              auc         pr@5%            re@5%
0   a1  b1   gbt  c1  82.5(80.1-83.6)  0.3(0.2-0.4)  60.2(58.1-61.3)
1   a1  b1   gbt  c1  84.3(84.1-84.6)  0.3(0.2-0.4)  60.2(58.1-61.3)
2   a1  b1   gbt  c1  84.3(84.1-84.6)  0.8(0.7-0.9)  60.2(58.1-61.3)
3   a1  b1   llm  c1  84.3(84.1-84.6)  0.8(0.7-0.9)  61.2(61.1-61.3)
4   a1  b1   llm  c1  84.1(83.8-84.5)  0.8(0.7-0.9)  61.2(61.1-61.3)
5   a1  b1   llm  c1  84.1(83.8-84.5)  0.5(0.4-0.6)  61.2(61.1-61.3)
6   a1  b1   rnn  c1  84.1(83.8-84.5)  0.5(0.4-0.6)  61.5(61.4-61.7)
7   a1  b1   rnn  c1  82.5(80.1-83.6)  0.5(0.4-0.6)  61.5(61.4-61.7)
8   a1  b1   rnn  c1  82.5(80.1-83.6)  0.3(0.2-0.4)  61.5(61.4-61.7)
9   a1  b1   gbt  c2  82.5(80.1-83.6)  0.3(0.2-0.4)  60.2(58.1-61.3)
10  a1  b1   gbt  c2  84.3(84.1-84.6)  0.3(0.2-0.4)  60.2(58.1-61.3)
11  a1  b1   gbt  c2  84.3(84.1-84.6)  0.8(0.7-0.9)  60.2(58.1-61.3)
12  a1  b1   llm  c2  84.3(84.1-84.6)  0.8(0.7-0.9)  61.2(61.1-61.3)

相关问题