R: model not retrieved correctly from MongoDB

guicsvcw  posted on 2023-09-27  in Go

I have an R script that creates a model, serializes it, and stores it in the models collection of the test MongoDB database:

library(mongolite)

mongo_host="localhost"
mongo_port=27017
url_path = sprintf("mongodb://%s:%s", mongo_host, mongo_port)  
mongo_database="test"

mongo_collection <- "models"  
mongo_con<-mongo(collection = mongo_collection
                 ,url = paste0(url_path,"/",mongo_database))

mySerializationFunc<-function(value){
  return (base64enc::base64encode(serialize(value, NULL,refhook = function(x) "dummy value")))
}

myUnserializationFunc<-function(value){
 return (unserialize(value,refhook = function(chr) list(dummy = 0L)))
}

insertDocumentIntoCollection <- function(connection,object) {
  str<-paste0('{"modelName": "',object$modelName,'", "objectModel" :',paste0('{"$binary":{"base64":"',mySerializationFunc(object$objectModel),'","subType": "0"}}}'))
  connection$insert(str)
}

getDocumentFromCollection<-function(connection,modelName){
 
 strConditions=paste0('{"modelName":"',modelName,'"}')
 strSelect=paste0('{"objectModel":true,"_id":false}')
 return(connection$find(query=strConditions,fields=strSelect))
}

modelName<-"irisTestAll"

lst<-list()
lst$modelName<-modelName
lst$objectModel<-randomForest::randomForest(as.formula("Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width"),iris)

# Store the model in mongoDB
insertDocumentIntoCollection(mongo_con,lst)

I can then retrieve the model, unserialize it, and run predictions:

# Retrieve the model
mdl<-getDocumentFromCollection(mongo_con,modelName)

# By using "mdl[[1]][[1]]" we always get the first model
mdl<-myUnserializationFunc(mdl[[1]][[1]])

predict(mdl,iris)

Now I have created a Shiny version of the model-creation step (exactly the same code):

library(shiny)
library(mongolite)

mongo_host="localhost"
mongo_port=27017
url_path = sprintf("mongodb://%s:%s", mongo_host, mongo_port)  
mongo_database="test"
mongo_collection <- "models"  

mongo_con<-mongo(collection = mongo_collection
                 ,url = paste0(url_path,"/",mongo_database))

mySerializationFunc<-function(value){
  return (base64enc::base64encode(serialize(value, NULL,refhook = function(x) "dummy value")))
}

insertDocumentIntoCollection <- function(connection,object) {
  str<-paste0('{"modelName": "',object$modelName,'", "objectModel" :',paste0('{"$binary":{"base64":"',mySerializationFunc(object$objectModel),'","subType": "0"}}}'))
  connection$insert(str)
}

ui <- fluidPage(
  actionButton("aa","Generate model")
)

server <- function(input, output, session){
  observeEvent(input$aa,{
    lst<-list()
    lst$modelName<-"irisTestAll"
    lst$objectModel<-randomForest::randomForest(as.formula("Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width"),iris)
    
    # Store the model in mongoDB
    insertDocumentIntoCollection(mongo_con,lst)    
    
  })
}

shinyApp(ui, server)

The app seems to work fine, but when I retrieve the model from MongoDB (the one stored by the app) to run predictions:

modelName<-"irisTestAll"
# Retrieve the model
mdl<-getDocumentFromCollection(mongo_con,modelName)

# By using "mdl[[1]][[1]]" we always get the first model
mdl<-myUnserializationFunc(mdl[[1]][[1]])

predict(mdl,iris)

I get this error:

Error in eval(predvars, data, env) : 
  invalid 'enclos' argument of type 'list'

So storing the model from the R console seems to work, but storing it from Shiny fails. Is there a way to fix this? Thanks.

enxuqcxy1#

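The problem is neither Shiny nor MongoDB; it is the serialization/unserialization.
After the randomForest model object is unserialized, the environment attached to its terms contains a dummy value: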

mdl <- getDocumentFromCollection(mongo_con,modelName)
mdl <- myUnserializationFunc(mdl[[1]][[1]])

attr(mdl$terms, ".Environment")
# $dummy
# [1] 0

This causes the prediction error:

predict(mdl, newdata=iris)
# Error in eval(predvars, data, env) : 
# invalid 'enclos' argument of type 'list'
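The difference between the console and the Shiny app can be reproduced without MongoDB or Shiny. The sketch below assumes the cause is the refhook pair from the question: R does not hand the global environment to the serialization hook, but it does hand over the local environment of a function call, which is what a formula built inside an observer carries. The helper names ser, unser and make_formula are made up for this illustration.

```r
# Hooks copied from the question: every reference object passed to the
# hook is stored as a placeholder string and restored as list(dummy = 0L).
ser   <- function(x) serialize(x, NULL, refhook = function(x) "dummy value")
unser <- function(x) unserialize(x, refhook = function(chr) list(dummy = 0L))

# Case 1: a formula whose environment is the global environment,
# as when the script is run at the top level of the console.
f_global <- as.formula("y ~ x", env = globalenv())

# Case 2: a formula created inside a function (a stand-in for the
# Shiny observer). Its environment is that call's local frame.
make_formula <- function() as.formula("y ~ x")
f_local <- make_formula()

env_global <- attr(unser(ser(f_global)), ".Environment")
env_local  <- attr(unser(ser(f_local)),  ".Environment")

print(is.environment(env_global))  # the global environment survives the round trip
print(is.environment(env_local))   # the local frame was replaced by the placeholder list
str(env_local)
```

If this assumption holds for the original app, the model stored from the console keeps a usable environment, while the one stored from the observer comes back with a list where an environment is expected.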

Let's replace the environment properly (perhaps by modifying myUnserializationFunc()?), and prediction will work:

attr(mdl$terms, ".Environment") <- .GlobalEnv

attr(mdl$terms, ".Environment") # check
# <environment: R_GlobalEnv>

# now predict
predict(mdl, newdata=iris)
#        1        2        3        4        5        6        7        8   ...
# 5.102515 4.766670 4.666158 4.804332 5.055100 5.382859 4.891974 5.051596   
# ...
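One way to fold this repair into myUnserializationFunc() is sketched below. It is a suggestion, not a tested drop-in for the original app: the env argument is an assumed addition, only obj$terms is repaired, and lm() stands in for randomForest() so the example needs no extra packages. The function fit_in_function is a hypothetical stand-in for the Shiny observer.

```r
# Hypothetical replacement for myUnserializationFunc(); `env` is an assumed
# extra argument naming the environment to put back on the model's terms.
myUnserializationFunc <- function(value, env = globalenv()) {
  obj <- unserialize(value, refhook = function(chr) list(dummy = 0L))
  # Repair the terms environment only if the hook replaced it with a placeholder.
  if (is.list(obj) && !is.null(obj$terms) &&
      !is.environment(attr(obj$terms, ".Environment"))) {
    attr(obj$terms, ".Environment") <- env
  }
  obj
}

# Stand-in for the Shiny observer: the formula is built inside a function,
# so its environment is a local frame rather than the global environment.
fit_in_function <- function() {
  f <- as.formula("Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width")
  lm(f, iris)
}

# Serialize with the same refhook as the question (base64 step omitted).
raw_model <- serialize(fit_in_function(), NULL,
                       refhook = function(x) "dummy value")

mdl <- myUnserializationFunc(raw_model)
preds <- predict(mdl, newdata = iris)
print(head(preds))
```

Using the global environment as the replacement is enough here because the formula refers only to columns of the data passed to predict(). A formula that relies on variables from the fitting environment would need those variables to exist in whatever environment is supplied.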
