xgboost预测值都是0,但是rawprediction和probability看起来不错

nhaq1z21  于 2021-05-27  发布在  Spark
关注(0)|答案(0)|浏览(902)

使用xgboost v1.1.1和spark v3.0(可以更改为以前的版本)以及java v8。
我的数据集非常简陋。

Dataset<Row> df = spark.read().format("libsvm")
    .load("data/sample-ml/simplegauss.txt");
df.show(20, false);

给予:

+----------------+--------------+
|label           |features      |
+----------------+--------------+
|1.0             |(1,[0],[1.0]) |
|2.0             |(1,[0],[2.0]) |
|4.0             |(1,[0],[3.0]) |
|8.0             |(1,[0],[4.0]) |
|16.0            |(1,[0],[5.0]) |
|32.0            |(1,[0],[6.0]) |
|32.0            |(1,[0],[7.0]) |
|16.0            |(1,[0],[8.0]) |
|8.0             |(1,[0],[9.0]) |
...

然后我创建我的助推器、模型,并测试我自己的数据:

XGBoostClassifier booster = new XGBoostClassifier()
    .setSilent(1)
    .setNumWorkers(2)
    .setNumRound(100)
    .setMaxDepth(5)
    .setMissing(0f)
    .setLabelCol("label")
    .setFeaturesCol("features");
XGBoostClassificationModel model = booster.fit(df);
Dataset<Row> tDf = model.transform(df);
tDf.show(20, false);

在这一点上,我得到:

+----------------+--------------+--------------------+--------------------+----------+
|label           |features      |rawPrediction       |probability         |prediction|
+----------------+--------------+--------------------+--------------------+----------+
|1.0             |(1,[0],[1.0]) |[1.0009020566940308]|[1.0009020566940308]|0.0       |
|2.0             |(1,[0],[2.0]) |[1.9988372325897217]|[1.9988372325897217]|0.0       |
|4.0             |(1,[0],[3.0]) |[4.001410007476807] |[4.001410007476807] |0.0       |
|8.0             |(1,[0],[4.0]) |[7.999919414520264] |[7.999919414520264] |0.0       |
|16.0            |(1,[0],[5.0]) |[15.997389793395996]|[15.997389793395996]|0.0       |
|32.0            |(1,[0],[6.0]) |[32.0023307800293]  |[32.0023307800293]  |0.0       |
|32.0            |(1,[0],[7.0]) |[32.00194549560547] |[32.00194549560547] |0.0       |
...

当然,稍后如果我使用预测,它会给我0:

Double feature = 2.0;
Vector features = Vectors.dense(feature);
double p = model.predict(features);
System.out.println("Prediction for feature " + feature + " is " + p +
    " (expected: 2)");

退货:

Prediction for feature 2.0 is 0.0 (expected: 2)

我怎样才能在 prediction 列?
我刚刚发现xgboost,所以我可能做错了什么:(。

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题