如何计算回归树R的均方误差?

ehxuflar  于 2022-12-25  发布在  其他
关注(0)|答案(1)|浏览(208)

我正在使用wine quality database
我研究回归树取决于不同的变量为:

library(rpart)
library(rpart.plot)
library(rattle)
library(naniar)
library(dplyr)
library(ggplot2)

vinos <- read.csv(file = 'Wine.csv', header = T)

arbol0<-rpart(formula=quality~chlorides, data=vinos, method="anova")
fancyRpartPlot(arbol0)

arbol1<-rpart(formula=quality~chlorides+density, data=vinos, method="anova")
fancyRpartPlot(arbol1)

我想计算均方误差,看看arbol1是否优于arbol0。由于没有更多数据可用,我将使用自己的数据集。我尝试按

aaa<-predict(object=arbol0, newdata=data.frame(chlorides=vinos$chlorides), type="anova")
bbb<-predict(object=arbol1, newdata=data.frame(chlorides=vinos$chlorides, density=vinos$density), type="anova")

然后手动从aaabbb中减去 Dataframe 的最后一列。但是,我得到了一个错误。有人能帮助我吗?

y4ekin9u

y4ekin9u1#

这个website可能对你有用。在训练你的模型之前,把你的数据集分成训练和测试子集是非常重要的。在下面的代码中,我用base函数完成了这一步。但caTools包中还有一个名为sample.split的函数,它执行相同的过程。我附上了这个website,您可以在其中看到在R中拆分数据的所有方法。
请记住,均方误差(MSE)的函数如下所示:

因此,将其应用于R非常简单。您只需计算观测值(即来自测试子集的响应变量)和预测值(即使用predict函数从模型预测的值)之间的平方差的均值。
葡萄酒数据集的一个解决方案可以是这个,它基于以前的网站。

library(rpart)
library(dplyr)
library(data.table)

vinos <- fread(file = 'Winequality-red.csv', header = TRUE)

# Split data into train and test subsets
sample_index <- sample(nrow(vinos), size = nrow(vinos)*0.75)
train <- vinos[sample_index, ]
test <- vinos[-sample_index, ]

# Train regression trees models
arbol0 <- rpart(formula = quality ~ chlorides, data = train, method = "anova")
arbol1 <- rpart(formula = quality ~ chlorides + density, data = train, method = "anova")

# Make predictions for each model
pred0 <- predict(arbol0, newdata = test)
pred1 <- predict(arbol1, newdata = test)

# Calculate MSE for each model
mean((pred0 - test$quality)^2)
mean((pred1 - test$quality)^2)

相关问题