I have this code:
pipeline = Pipeline(stages=[indexer, encoder, categorical_assembler, numerical_assembler, standardScaler, final_assembler, regressor1])
train_data, test_data = df.randomSplit([.8,.2], seed=RANDOM_SEED)
model1 = pipeline.fit(train_data)
pred_results = model1.transform(test_data)
results = pred_results.select('prediction', 'median_house_value')
results = results.withColumn('prediction', F.round('prediction'))
results.show(5)
It works fine. Then I want to compute the model's metrics:
evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="median_house_value")
r2 = evaluator.evaluate(results, {evaluator.metricName: "r2"})
r2
Every time, it fails with an error like this:
Py4JJavaError                             Traceback (most recent call last)
Cell In[71], line 6
      3 display(HTML(test_data.toPandas().info()))
      4 #display(HTML(results.toPandas().info()))
      5 # Calculate the R2 metric
----> 6 r2 = evaluator.evaluate(results, {evaluator.metricName: "r2"})
      7 r2

File ~\AppData\Roaming\Python\Python310\site-packages\pyspark\ml\evaluation.py:109, in Evaluator.evaluate(self, dataset, params)
File ~\AppData\Roaming\Python\Python310\site-packages\pyspark\ml\evaluation.py:148, in JavaEvaluator._evaluate(self, dataset)
File ~\anaconda3\lib\site-packages\py4j\java_gateway.py:1322, in JavaMember.__call__(self, *args)
File ~\AppData\Roaming\Python\Python310\site-packages\pyspark\errors\exceptions\captured.py:169, in capture_sql_exception.<locals>.deco(*a, **kw)
File ~\anaconda3\lib\site-packages\py4j\protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)

Py4JJavaError: An error occurred while calling o2827.evaluate.
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 102.0 failed 1 times, most recent failure: Lost task 0.0 in stage 102.0 (TID 85) (DESKTOP-9JCLV0J executor driver):
org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (StringIndexerModel$$Lambda$3713/2049034039: (string) => double).
    at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:217)
    ... (executor stack frames elided) ...
Caused by: org.apache.spark.SparkException: Unseen label: ISLAND. To handle unseen labels, set Param handleInvalid to keep.
    at org.apache.spark.ml.feature.StringIndexerModel.$anonfun$getIndexer$1(StringIndexer.scala:406)
    at org.apache.spark.ml.feature.StringIndexerModel.$anonfun$getIndexer$1$adapted(StringIndexer.scala:391)
    ... 32 more

Driver stacktrace:
    ... (driver stack frames elided; they end in the same FAILED_EXECUTE_UDF and "Unseen label: ISLAND" cause) ...
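The last "Caused by" line points at the StringIndexer: the label ISLAND apparently occurs in test_data but was never seen when the indexer was fit on train_data. A quick way to check this, assuming the indexed column is ocean_proximity (a guess; the actual column name isn't shown above):

# Find categories that occur in test_data but not in train_data.
# 'ocean_proximity' is an assumed column name.
train_labels = {row[0] for row in train_data.select("ocean_proximity").distinct().collect()}
test_labels = {row[0] for row in test_data.select("ocean_proximity").distinct().collect()}
print(test_labels - train_labels)  # expected to print something like {'ISLAND'}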
I have tested this code on a small sample of the dataset and it works there. I expect to get something like:
0.6152
Even stranger, when I run:
evaluator = RegressionEvaluator(predictionCol="median_income", labelCol="median_house_value")
display(HTML(test_data.toPandas().info()))
r2 = evaluator.evaluate(test_data, {evaluator.metricName: "r2"})
r2
it works without any problems!
1 Answer
I managed to solve the problem by adding .setHandleInvalid("keep") to the StringIndexer.
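A minimal sketch of what that change looks like, assuming the indexed column is ocean_proximity (a guess; the original indexer definition isn't shown):

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer

# 'ocean_proximity' is an assumed column name. handleInvalid="keep" maps
# labels unseen during fit() to an extra index instead of throwing
# "Unseen label" at transform time.
indexer = StringIndexer(
    inputCol="ocean_proximity",
    outputCol="ocean_proximity_index",
).setHandleInvalid("keep")

pipeline = Pipeline(stages=[indexer, encoder, categorical_assembler,
                            numerical_assembler, standardScaler,
                            final_assembler, regressor1])
model1 = pipeline.fit(train_data)
pred_results = model1.transform(test_data)

This also explains the puzzling behavior above: transform() is lazy, so model1.transform(test_data) returns without touching the data, and the unseen-label failure only surfaces when evaluate() forces the rows to be computed. Evaluating test_data directly with predictionCol="median_income" works because that evaluator only reads columns that already exist and never runs the StringIndexer.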