在本地计算机中执行时,尝试使用spark和scala创建决策树时出现getting task not serializable错误

k5hmc34c  于 2021-05-17  发布在  Spark
关注(0)|答案(0)|浏览(379)

我试图创建欺诈交易探测器使用Spark与scala。我的代码在正常的spark逻辑下运行良好。然而,当我尝试使用决策树方法的解决方案时,我得到了task not serializable错误。据我所知,当我尝试将我的训练数据拟合到管道中时,会出现这个错误。我尝试了一些解决方案,比如扩展到可序列化、将索引数据转换回字符串等,但都没有成功。
有人能帮我理解我做错了什么吗
下面是我的代码

package com.vinspark.frauddetection

     object FraudDet_ML extends Serializable{

     def main(args:Array[String]){

     Logger.getLogger("org").setLevel(Level.ERROR)

     val spark=SparkSession
      .builder()
      .appName("FraudDetection")
      .master("local[*]")
      .config("spark.sql.warehouse.dir","file:///C:/temp")
      .getOrCreate()

     import spark.sqlContext.implicits._

     var df=spark.read.format("csv").option("header", "true").option("mode", 
     "DROPMALFORMED").option("inferSchema", "true").load("../PS_20174392719_1491204439457_log.csv")

     df= df.withColumn("orgDiff", col("newbalanceOrig") - 
      col("oldbalanceOrg")).withColumn("destDiff", 
      col("newbalanceDest") - col("oldbalanceDest"))

     df=   df.withColumn("label",
       when((col("oldbalanceOrg") <=56900 && col("type")=="TRANSFER" && col("newbalanceDest") <= 105)
       ||(col("oldbalanceOrg") >56900 && col("newbalanceOrig")<=12)
       ||(col("oldbalanceOrg") >56900 && col("newbalanceOrig")>12 && col("amount")>1160000),
     1)
     .otherwise(0) 
     )

     df.createOrReplaceTempView("transaction")

     val indexer = new StringIndexer().setInputCol("type").setOutputCol("typeIndexed")   
     println("indexer created")

     val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel").fit(df)

     val splitDataSet: Array[org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]] 
       =df.randomSplit(Array(0.8, 0.2), 12345L)
     val trainDataFrame = splitDataSet(0)val testDataFrame = splitDataSet(1)
     val train = splitDataSet(0)
     val test = splitDataSet(1)
     println("train: "+train.count()+" test: "+test.count())

     val va = new VectorAssembler().setInputCols(Array("typeIndexed", "amount", "oldbalanceOrg", 
      "newbalanceOrig", "oldbalanceDest", "newbalanceDest", "orgDiff", 
      "destDiff")).setOutputCol("features")

     println("vector assembler created")    
     val dt = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features").setSeed(54321).setMaxDepth(5)

      println("decision tree created")

      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

       val pipeline = new Pipeline().setStages(Array(indexer,labelIndexer, va, dt))  
       println("pipeline created")   
       val pipeConst=pipeline  
       val model = pipeConst.fit(train)  
       val model1 = model  
       println("model created")
       val prediction=model1.transform(test)
       println("indexer created")
       prediction.collect().foreach(println)
      }
      }

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题