局部训练与dataproc训练sparkml模型的不一致性

n53p2ov0  于 2021-05-27  发布在  Spark
关注(0)|答案(1)|浏览(481)

我正在将spark从2.3.1版升级到2.4.5版。我正在使用dataproc image 1.4.27-debian9在google云平台的dataproc上重新培训spark2.4.5模型。当我在本地机器上加载dataproc生成的模型时,使用spark2.4.5验证模型。不幸的是,我得到了以下例外:

  1. 20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
  2. 20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
  3. Exception in thread "main" java.lang.IllegalArgumentException: gbtc_961a6ef213b2 parameter impurity given invalid value variance.

加载模型的代码非常简单:

  1. import org.apache.spark.ml.PipelineModel
  2. object ModelLoad {
  3. def main(args: Array[String]): Unit = {
  4. val modelInputPath = getClass.getResource("/model.ml").getPath
  5. val model = PipelineModel.load(modelInputPath)
  6. }
  7. }

我沿着烟囱的轨迹去检查 1_gbtc_961a6ef213b2/metadata/part-00000 模型元数据文件并找到以下内容:

  1. {
  2. "class": "org.apache.spark.ml.classification.GBTClassificationModel",
  3. "timestamp": 1590593177604,
  4. "sparkVersion": "2.4.5",
  5. "uid": "gbtc_961a6ef213b2",
  6. "paramMap": {
  7. "maxIter": 50
  8. },
  9. "defaultParamMap": {
  10. ...
  11. "impurity": "variance",
  12. ...
  13. },
  14. "numFeatures": 1,
  15. "numTrees": 50
  16. }

杂质设置为 variance 但我的本地spark 2.4.5预计 gini . 为了进行合理性检查,我在本地spark2.4.5上重新训练了模型。这个 impurity 模型中的元数据文件设置为 gini .
所以,我检查了gbtjavadoc中的spark2.4.5setinclusion方法。上面写着 The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance." . dataproc使用的spark2.4.5似乎与apachespark文档一致。但是,我从maven central使用的spark 2.4.5设置了 impurity 价值 gini .
有人知道为什么dataproc中的spark2.4.5和maven central之间会有这样的不一致吗?
我创建了一个简单的训练代码来在本地重现结果:

  1. import java.nio.file.Paths
  2. import org.apache.spark.ml.classification.GBTClassifier
  3. import org.apache.spark.ml.feature.VectorAssembler
  4. import org.apache.spark.ml.{Pipeline, PipelineModel}
  5. import org.apache.spark.sql.{DataFrame, SparkSession}
  6. object SimpleModelTraining {
  7. def main(args: Array[String]) {
  8. val currentRelativePath = Paths.get("")
  9. val save_file_location = currentRelativePath.toAbsolutePath.toString
  10. val spark = SparkSession.builder()
  11. .config("spark.driver.host", "127.0.0.1")
  12. .master("local")
  13. .appName("spark-test")
  14. .getOrCreate()
  15. val df: DataFrame = spark.createDataFrame(Seq(
  16. (0, 0),
  17. (1, 0),
  18. (1, 0),
  19. (0, 1),
  20. (0, 1),
  21. (0, 1),
  22. (0, 2),
  23. (0, 2),
  24. (0, 2),
  25. (0, 3),
  26. (0, 3),
  27. (0, 3),
  28. (1, 4),
  29. (1, 4),
  30. (1, 4)
  31. )).toDF("label", "category")
  32. val pipeline: Pipeline = new Pipeline().setStages(Array(
  33. new VectorAssembler().setInputCols(Array("category")).setOutputCol("features"),
  34. new GBTClassifier().setMaxIter(30)
  35. ))
  36. val pipelineModel: PipelineModel = pipeline.fit(df)
  37. pipelineModel.write.overwrite().save(s"$save_file_location/test_model.ml")
  38. }
  39. }

谢谢您!

xa9qqrwz

xa9qqrwz1#

dataproc中的spark为spark-25959提供了一个修复程序,该修复程序可能会导致本地训练的和dataproc训练的ml模型之间的不一致。

相关问题