Partykit条件推理树(分类树)中内部节点的显示概率

ohtdti5x  于 2023-02-06  发布在  其他
关注(0)|答案(1)|浏览(163)

在partykit包打印(i. ctree)提供了在终端节点(分类树)的结果的概率。然而,我想知道的结果在内部节点的概率以及。
当我为ctree图创建条形图(i.ctree,inner_panel = node_barplot)时,我可以估计内部节点的概率,但我想要的是内部节点的确切概率,例如,在下图中,我想知道节点2和5的结果概率。
有什么想法吗?
https://cran.r-project.org/web/packages/partykit/vignettes/ctree.pdf为例:
ctree条形图:

0yycz8jy

0yycz8jy1#

有多种方法可以提取属于某个节点的全部数据并计算您感兴趣的任何数量。对于分类树的分布,一种方法是强制到simpleparty类,该类将distribution存储在每个节点的info槽中。
使用您提到的小插图中的示例,首先可以拟合完整的constparty树:

library("partykit")
data("GlaucomaM", package = "TH.data")
gtree <- ctree(Class ~ ., data = GlaucomaM)

然后强制为simpleparty

gtree <- as.simpleparty(gtree)

然后,您可以从每个节点提取分布列表,将其绑定到表中,并计算比例:

tab <- nodeapply(gtree, nodeids(gtree), function(node) node$info$distribution)
tab <- do.call(rbind, tab)
proportions(tab, 1)
##     glaucoma     normal
## 1 0.50000000 0.50000000
## 2 0.86206897 0.13793103
## 3 0.93670886 0.06329114
## 4 0.12500000 0.87500000
## 5 0.21100917 0.78899083
## 6 0.09230769 0.90769231
## 7 0.38636364 0.61363636

您还可以重新使用print.simpleparty中使用的函数,调整面板函数以进行打印:

simpleprint <- function(node) formatinfo_node(node,
  FUN = partykit:::.make_formatinfo_simpleparty(gtree),
  default = "*", prefix = ": ")
print(gtree, inner_panel = simpleprint)
## Model formula:
## Class ~ ag + at + as + an + ai + eag + eat + eas + ean + eai + 
##     abrg + abrt + abrs + abrn + abri + hic + mhcg + mhct + mhcs + 
##     mhcn + mhci + phcg + phct + phcs + phcn + phci + hvc + vbsg + 
##     vbst + vbss + vbsn + vbsi + vasg + vast + vass + vasn + vasi + 
##     vbrg + vbrt + vbrs + vbrn + vbri + varg + vart + vars + varn + 
##     vari + mdg + mdt + mds + mdn + mdi + tmg + tmt + tms + tmn + 
##     tmi + mr + rnf + mdic + emd + mv
## 
## Fitted party:
## [1] root
## |   [2] vari <= 0.059: glaucoma (n = 87, err = 13.8%)
## |   |   [3] vasg <= 0.066: glaucoma (n = 79, err = 6.3%)
## |   |   [4] vasg > 0.066: normal (n = 8, err = 12.5%)
## |   [5] vari > 0.059: normal (n = 109, err = 21.1%)
## |   |   [6] tms <= -0.066: normal (n = 65, err = 9.2%)
## |   |   [7] tms > -0.066: normal (n = 44, err = 38.6%)
## 
## Number of inner nodes:    3
## Number of terminal nodes: 4

相关问题