XGBoost predictions are all 0, but rawPrediction and probability look fine

Posted by nhaq1z21 on 2021-05-27 in Spark

I am using XGBoost v1.1.1 with Spark v3.0 (I can switch to earlier versions) and Java 8.
My dataset is very basic.

    Dataset<Row> df = spark.read().format("libsvm")
        .load("data/sample-ml/simplegauss.txt");
    df.show(20, false);

This gives:

    +----------------+--------------+
    |label           |features      |
    +----------------+--------------+
    |1.0             |(1,[0],[1.0]) |
    |2.0             |(1,[0],[2.0]) |
    |4.0             |(1,[0],[3.0]) |
    |8.0             |(1,[0],[4.0]) |
    |16.0            |(1,[0],[5.0]) |
    |32.0            |(1,[0],[6.0]) |
    |32.0            |(1,[0],[7.0]) |
    |16.0            |(1,[0],[8.0]) |
    |8.0             |(1,[0],[9.0]) |
    ...

Then I create my booster, fit the model, and test it on my own data:

    XGBoostClassifier booster = new XGBoostClassifier()
        .setSilent(1)
        .setNumWorkers(2)
        .setNumRound(100)
        .setMaxDepth(5)
        .setMissing(0f)
        .setLabelCol("label")
        .setFeaturesCol("features");
    XGBoostClassificationModel model = booster.fit(df);
    Dataset<Row> tDf = model.transform(df);
    tDf.show(20, false);

At this point, I get:

    +----------------+--------------+--------------------+--------------------+----------+
    |label           |features      |rawPrediction       |probability         |prediction|
    +----------------+--------------+--------------------+--------------------+----------+
    |1.0             |(1,[0],[1.0]) |[1.0009020566940308]|[1.0009020566940308]|0.0       |
    |2.0             |(1,[0],[2.0]) |[1.9988372325897217]|[1.9988372325897217]|0.0       |
    |4.0             |(1,[0],[3.0]) |[4.001410007476807] |[4.001410007476807] |0.0       |
    |8.0             |(1,[0],[4.0]) |[7.999919414520264] |[7.999919414520264] |0.0       |
    |16.0            |(1,[0],[5.0]) |[15.997389793395996]|[15.997389793395996]|0.0       |
    |32.0            |(1,[0],[6.0]) |[32.0023307800293]  |[32.0023307800293]  |0.0       |
    |32.0            |(1,[0],[7.0]) |[32.00194549560547] |[32.00194549560547] |0.0       |
    ...

Of course, if I then call predict, it also gives me 0:

    Double feature = 2.0;
    Vector features = Vectors.dense(feature);
    double p = model.predict(features);
    System.out.println("Prediction for feature " + feature + " is " + p +
        " (expected: 2)");

This prints:

    Prediction for feature 2.0 is 0.0 (expected: 2)

How can I get the right value in the prediction column?
I have only just discovered XGBoost, so I may well be doing something wrong :(.
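
One plausible reading of the output, offered as an assumption rather than a confirmed diagnosis: the labels (1, 2, 4, ... 32) are continuous targets, and rawPrediction reproduces them almost exactly, so the booster appears to be fitting a regression even though it is wrapped in a classifier. Spark ML's probabilistic classification models, when no thresholds are set, take the prediction to be the index of the largest entry in the probability vector. Here that vector has a single element, so the index can only ever be 0. The sketch below is plain Java with no Spark dependency; the class and method names (ArgmaxSketch, argmax) are hypothetical and only mimic that rule on the values shown in the table.

```java
import java.util.Arrays;

public class ArgmaxSketch {

    // Returns the index of the largest entry. This mimics (as an assumption)
    // how the prediction column is derived from the probability vector
    // when no thresholds are configured.
    public static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // One-element "probability" vectors copied from the output above.
        double[][] probabilityColumn = {
            {1.0009020566940308},
            {1.9988372325897217},
            {32.0023307800293}
        };
        for (double[] probability : probabilityColumn) {
            // With a single element the only possible index is 0,
            // whatever the value is.
            double prediction = argmax(probability);
            System.out.println(Arrays.toString(probability)
                + " -> prediction " + prediction);
        }
    }
}
```

If that reading is right, the value being looked for is already sitting in rawPrediction, and the estimator intended for continuous labels in xgboost4j-spark is XGBoostRegressor (same package as XGBoostClassifier). Whether swapping it in resolves this exact setup has not been tested here.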
