R: How to calculate log loss in machine learning

Posted by aiqt4smr on 2022-12-06 in Other

The following code produces probability outputs for binary classification with a random forest.

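library(randomForest)

# randomForest() fits a regression forest if the labels are numeric, so pass them as a factor
# to get a classification forest; predict(type = "prob") then returns class probabilities
rf <- randomForest(train, as.factor(train_label), importance = TRUE, proximity = TRUE)
prediction <- predict(rf, test, type = "prob")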

`prediction` is a matrix of predicted class probabilities, with one row per test case and one column per class.

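The true labels of the test data are known (stored in `test_label`). Now I want to compute the logarithmic loss (log loss) of these binary-classification probability outputs. My LogLoss function is: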

LogLoss <- function(actual, predicted) {
  # actual: 0/1 labels; predicted: predicted probability of class 1
  -mean(actual * log(predicted) + (1 - actual) * log(1 - predicted))
}
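For reference, the function above computes the standard binary log loss, where $y_i \in \{0, 1\}$ is the true label and $p_i$ is the predicted probability of class 1 for each of the $N$ test cases:

$$
\mathrm{LogLoss} = -\frac{1}{N}\sum_{i=1}^{N}\left[\, y_i \log p_i + (1 - y_i)\log(1 - p_i) \,\right]
$$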

How can I compute the log loss from the binary classifier's probability output? Thanks.
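A minimal sketch of one way to connect the two pieces, under stated assumptions: `test_label` is a factor whose second level (a hypothetical "yes") is the positive class, and a small made-up matrix stands in for the real `predict(rf, test, type = "prob")` output. The idea is to take the positive-class column, clip it away from 0 and 1, and convert the labels to 0/1 before calling `LogLoss`.

```r
# Question's LogLoss: actual must be numeric 0/1, predicted = P(class 1)
LogLoss <- function(actual, predicted) {
  -mean(actual * log(predicted) + (1 - actual) * log(1 - predicted))
}

# Hypothetical stand-in for predict(rf, test, type = "prob"):
# one row per test case, one column per class level
prediction <- matrix(c(0.9, 0.1,
                       0.2, 0.8,
                       0.6, 0.4),
                     ncol = 2, byrow = TRUE,
                     dimnames = list(NULL, c("no", "yes")))
test_label <- factor(c("no", "yes", "yes"), levels = c("no", "yes"))

# Assumption: "yes" (the second level) is the positive class
p_pos <- prediction[, "yes"]
# Clip so log() never receives exactly 0 or 1
p_pos <- pmin(pmax(p_pos, 1e-15), 1 - 1e-15)
# Convert the factor labels to 0/1
y <- as.integer(test_label == "yes")

LogLoss(y, p_pos)
# [1] 0.4149316
```

With a real forest, the column name would be whichever level of `train_label` is treated as the positive class.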

Answer #1 (emeijp43)

library(randomForest) 

rf <- randomForest(Species~., data = iris, importance=TRUE, proximity=TRUE)
prediction <- predict(rf, iris, type="prob")
# Clip probabilities to [1e-15, 1 - 1e-15]; otherwise log(0) returns -Inf and the loss becomes infinite
prediction <- apply(prediction, c(1,2), function(x) min(max(x, 1E-15), 1-1E-15)) 

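# model.matrix(~ actual + 0) builds a one-hot indicator matrix (1 for the true class, 0 elsewhere).
# Subtracting the clipped predictions leaves a positive value only in each row's true-class column,
# so the index selects the predicted probability of the true class for every observation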
logLoss = function(pred, actual){
  -1*mean(log(pred[model.matrix(~ actual + 0) - pred > 0]))
}

logLoss(prediction, iris$Species)
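The indexing trick in `logLoss` is compact but hard to read. As a sketch, the block below compares it with a more explicit variant (`logLoss_explicit`, a name introduced here for illustration) that picks each row's true-class probability by matrix indexing. It uses a hypothetical 3-class probability matrix instead of the randomForest output and assumes, as with randomForest, that the columns are ordered like `levels(actual)`.

```r
# First answer's multiclass log loss (expects clipped probabilities)
logLoss <- function(pred, actual) {
  -1 * mean(log(pred[model.matrix(~ actual + 0) - pred > 0]))
}

# Explicit variant: row i, column = level index of actual[i]
logLoss_explicit <- function(pred, actual) {
  -mean(log(pred[cbind(seq_along(actual), as.integer(actual))]))
}

# Hypothetical 3-class probabilities, columns ordered like levels(actual)
actual <- factor(c("a", "b", "c", "a"), levels = c("a", "b", "c"))
pred <- matrix(c(0.7, 0.2,  0.1,
                 0.1, 0.8,  0.1,
                 0.3, 0.3,  0.4,
                 0.5, 0.25, 0.25),
               ncol = 3, byrow = TRUE)
pred <- pmin(pmax(pred, 1e-15), 1 - 1e-15)

logLoss(pred, actual)
# [1] 0.5473141
logLoss_explicit(pred, actual)
# [1] 0.5473141
```

Both versions average $-\log p$ of the true class, here $-\tfrac{1}{4}(\log 0.7 + \log 0.8 + \log 0.4 + \log 0.5)$.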

Answer #2 (kqlmhetl)

I think the logLoss formula is a bit wrong.

model <- glm(vs ~ mpg, data = mtcars, family = "binomial")

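### Formula from the first answer (wrong for this input):
### given a vector of P(y = 1) and 0/1 labels, it averages log(pred) only over rows with actual == 1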
logLoss1 <- function(pred, actual){
  -1*mean(log(pred[model.matrix(~ actual + 0) - pred > 0]))
}
logLoss1(actual = model$y, pred = model$fitted.values)
# [1] 0.4466049

### Correct formula in native R 
logLoss2 <- function(pred, actual){
  -mean(actual * log(pred) + (1 - actual) * log(1 - pred))
}
logLoss2(actual = model$y, pred = model$fitted.values)
# [1] 0.3989584

## Results from various packages to verify the correct answer

### From ModelMetrics package
ModelMetrics::logLoss(actual = model$y, pred = model$fitted.values)
# [1] 0.3989584

### From MLmetrics package
MLmetrics::LogLoss(y_pred = model$fitted.values, y_true = model$y)
# [1] 0.3989584

### From reticulate package (requires Python with scikit-learn installed)
sklearn.metrics <- reticulate::import("sklearn.metrics")
sklearn.metrics$log_loss(y_true = model$y, y_pred = model$fitted.values)
# [1] 0.3989584

I used R version 4.1.0 (2021-05-18).
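A sketch, not taken from either answer: the first answer's function was written for a full class-probability matrix with factor labels, as in its iris example. Building a two-column matrix from `fitted.values` and passing a factor label should make that formula agree with `logLoss2`. The variable names below (`p`, `y`, `prob_mat`, `y_fac`, and so on) are introduced for illustration; only base R is needed.

```r
# Binary logistic model from the answer above
model <- glm(vs ~ mpg, data = mtcars, family = "binomial")
p <- model$fitted.values   # P(vs == 1) for each car
y <- model$y               # observed 0/1 labels

# Direct binary log loss (same formula as logLoss2)
binary_ll <- -mean(y * log(p) + (1 - y) * log(1 - p))

# First answer's multiclass form, fed a full two-column matrix
# (column 1 = P(class "0"), column 2 = P(class "1")) and a factor label
prob_mat <- cbind(1 - p, p)
prob_mat <- pmin(pmax(prob_mat, 1e-15), 1 - 1e-15)
y_fac <- factor(y, levels = c(0, 1))
multiclass_ll <- -mean(log(prob_mat[model.matrix(~ y_fac + 0) - prob_mat > 0]))

all.equal(binary_ll, multiclass_ll)
# [1] TRUE
```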
