Scala Spark UDF ClassCastException: WrappedArray$ofRef cannot be cast to [Lscala.Tuple2

r6vfmomb asked on 2021-07-13 in Spark

So I do the necessary imports etc.:

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import spark.implicits._

and then define some lat/long points:

val london = (1.0, 1.0)
val suburbia = (2.0, 2.0)
val southampton = (3.0, 3.0)  
val york = (4.0, 4.0)

Then I create a Spark DataFrame like this and check that it works:

val exampleDF = Seq((List(london,suburbia),List(southampton,york)),
    (List(york,london),List(southampton,suburbia))).toDF("AR1","AR2")
exampleDF.show()
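
A quick printSchema call confirms the nested column types (output abbreviated):

exampleDF.printSchema()
// root
//  |-- AR1: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- _1: double (nullable = false)
//  |    |    |-- _2: double (nullable = false)
//  |-- AR2: array (nullable = true)
//  ...  (AR2 has the same element type)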

The DataFrame has the following type: DataFrame = [AR1: array<struct<_1:double,_2:double>>, AR2: array<struct<_1:double,_2:double>>]. I then create a function to build the combinations of points:

// function to do what I want
val latlongexplode = (x: Array[(Double, Double)], y: Array[(Double, Double)]) => {
  for (a <- x; b <- y) yield (a, b)
}

I check that the function works:

latlongexplode(Array(london,york),Array(suburbia,southampton))
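// For reference, the expected result is the cross product of the two arrays:
// Array(((1.0,1.0),(2.0,2.0)), ((1.0,1.0),(3.0,3.0)),
//       ((4.0,4.0),(2.0,2.0)), ((4.0,4.0),(3.0,3.0)))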

And it does. But after I create a UDF from this function

// declare function into a Spark UDF
val latlongexplodeUDF = udf(latlongexplode)

and try to use it on the Spark DataFrame I created above, like this:

exampleDF.withColumn("latlongexplode", latlongexplodeUDF($"AR1",$"AR2")).show(false)

I get a very long stack trace, which basically boils down to:
java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Lscala.Tuple2;
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$f$3(ScalaUDF.scala:121)
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1063)
  at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:151)
  at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:50)
  at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:32)
  at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:273)
How can I get this UDF to work in Scala Spark? (I'm currently on Spark 2.4, if that helps.)
Edit: It could be that the way I've constructed the example DF is part of the problem. But the actual data I have is an array of lat/long tuples (of unknown size) in each column.


9udxz4iz 1#

When struct types are used in a UDF, they are represented as Row objects, and array columns are represented as Seq. You also need to return the structs as Rows, and you need to define a schema in order to return structs.
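
To make the Row representation concrete, here is a more explicit variant of the UDF that unpacks each incoming Row by position; it behaves the same as the version in the full example below:

import org.apache.spark.sql.Row

val latlongexplodeVerbose = (x: Seq[Row], y: Seq[Row]) => {
    for (a <- x; b <- y) yield Row(
        Row(a.getDouble(0), a.getDouble(1)),  // city1: fields _1 and _2 read positionally
        Row(b.getDouble(0), b.getDouble(1)))  // city2
}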

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val london = (1.0, 1.0)
val suburbia = (2.0, 2.0)
val southampton = (3.0, 3.0)  
val york = (4.0, 4.0)
val exampleDF = Seq((List(london,suburbia),List(southampton,york)),
    (List(york,london),List(southampton,suburbia))).toDF("AR1","AR2")
exampleDF.show(false)
+------------------------+------------------------+
|AR1                     |AR2                     |
+------------------------+------------------------+
|[[1.0, 1.0], [2.0, 2.0]]|[[3.0, 3.0], [4.0, 4.0]]|
|[[4.0, 4.0], [1.0, 1.0]]|[[3.0, 3.0], [2.0, 2.0]]|
+------------------------+------------------------+
val latlongexplode = (x: Seq[Row], y: Seq[Row]) => {
    for (a <- x; b <- y) yield Row(a, b)
}

val udf_schema = ArrayType(
    StructType(Seq(
        StructField(
            "city1",
            StructType(Seq(
                StructField("lat", FloatType),
                StructField("long", FloatType)
            ))
        ),
        StructField(
            "city2",
            StructType(Seq(
                StructField("lat", FloatType),
                StructField("long", FloatType)
            ))
        )
    ))
)

// include this line if you see errors like 
// "You're using untyped Scala UDF, which does not have the input type information."
// spark.sql("set spark.sql.legacy.allowUntypedScalaUDF = true")

val latlongexplodeUDF = udf(latlongexplode, udf_schema)
val result = exampleDF.withColumn("latlongexplode", latlongexplodeUDF($"AR1", $"AR2"))
result.show(false)
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
|AR1                     |AR2                     |latlongexplode                                                                                          |
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
|[[1.0, 1.0], [2.0, 2.0]]|[[3.0, 3.0], [4.0, 4.0]]|[[[1.0, 1.0], [3.0, 3.0]], [[1.0, 1.0], [4.0, 4.0]], [[2.0, 2.0], [3.0, 3.0]], [[2.0, 2.0], [4.0, 4.0]]]|
|[[4.0, 4.0], [1.0, 1.0]]|[[3.0, 3.0], [2.0, 2.0]]|[[[4.0, 4.0], [3.0, 3.0]], [[4.0, 4.0], [2.0, 2.0]], [[1.0, 1.0], [3.0, 3.0]], [[1.0, 1.0], [2.0, 2.0]]]|
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
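
As a follow-up, if one city pair per row is more convenient than an array per row, the built-in explode function can flatten the result (a usage sketch building on result above):

import org.apache.spark.sql.functions.explode

result.select(explode($"latlongexplode").as("pair"))
    .select($"pair.city1", $"pair.city2")
    .show(false)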
