Problem inferring the data type of a complex struct field in Spark

fruv7luv · asked 2021-07-13 in Spark · 4 answers · 364 views

I have a Spark DataFrame as shown below. It has an array of structs in the zipped_feature column.

    +--------------------+
    |zipped_feature      |
    +--------------------+
    |[[A, 1], [ABC, 33]] |
    |[[A, 1], [ABS, 24]] |
    |[[B, 2], [ABE, 17]] |
    |[[C, 3], [ABC, 33]] |
    +--------------------+

I tried to fetch one item (also an array) from this array of structs by index, using the UDF below. For the first row, index 0 should retrieve "[A, 1]" as an array.

    val getValueUdf = udf { (zippedFeature: Seq[Seq[String]], index: Int) => zippedFeature(index) }

But I am getting the error below:

    data type mismatch: argument 1 requires array<array<string>> type, however, '`zipped_feature`' is of array<struct<_1:string,_2:string>> type.

When I print the schema, it shows the following:

    |-- zipped_feature: array (nullable = true)
    |    |-- element: struct (containsNull = true)
    |    |    |-- _1: string (nullable = true)
    |    |    |-- _2: string (nullable = true)

Can someone help me figure out what I am doing wrong? I want to get the value at a given index (again, as an array).

roejwanj1#

zipped_feature is a column of array-of-structs type. Spark passes each struct element into a Scala UDF as an org.apache.spark.sql.Row, so the UDF has to accept Seq[Row] rather than Seq[Seq[String]]. If you want each nested value as an array, modify the UDF as shown below.

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.sparkContext.setLogLevel("OFF")
    import spark.implicits._

    // construct a sample dataframe
    val rows = List(
      Row(Array(Row("A", "1"), Row("ABC", "33"))),
      Row(Array(Row("A", "1"), Row("ABS", "24"))),
      Row(Array(Row("B", "2"), Row("ABE", "17"))),
      Row(Array(Row("C", "3"), Row("ABC", "33"))))
    val rdd = spark.sparkContext.parallelize(rows)
    val schema = new StructType()
      .add("zipped_feature", ArrayType(new StructType().add("_1", StringType).add("_2", StringType)))
    val df = spark.createDataFrame(rdd, schema)
    df.show()
    /*
    +-------------------+
    |     zipped_feature|
    +-------------------+
    |[[A, 1], [ABC, 33]]|
    |[[A, 1], [ABS, 24]]|
    |[[B, 2], [ABE, 17]]|
    |[[C, 3], [ABC, 33]]|
    +-------------------+
    */
    df.printSchema()
    /*
    root
     |-- zipped_feature: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- _1: string (nullable = true)
     |    |    |-- _2: string (nullable = true)
    */

    // each struct element arrives in the UDF as a Row
    val getValueUdf = udf { (zippedFeature: Seq[Row], index: Int) => zippedFeature(index).toSeq.map(_.toString) }
    df.withColumn("first_column", getValueUdf('zipped_feature, lit(0)))
      .withColumn("second_column", getValueUdf('zipped_feature, lit(1)))
      .show(false)
    /* output
    +-------------------+------------+-------------+
    |zipped_feature     |first_column|second_column|
    +-------------------+------------+-------------+
    |[[A, 1], [ABC, 33]]|[A, 1]      |[ABC, 33]    |
    |[[A, 1], [ABS, 24]]|[A, 1]      |[ABS, 24]    |
    |[[B, 2], [ABE, 17]]|[B, 2]      |[ABE, 17]    |
    |[[C, 3], [ABC, 33]]|[C, 3]      |[ABC, 33]    |
    +-------------------+------------+-------------+
    */
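As a follow-up, the same function can also be registered for calling from Spark SQL. This is a hedged sketch, not part of the original answer; the UDF name get_value and the view name features are assumptions:

    // sketch: register the UDF under a made-up name so it can be called from SQL
    spark.udf.register("get_value", (zipped: Seq[Row], index: Int) => zipped(index).toSeq.map(_.toString))
    df.createOrReplaceTempView("features") // "features" is a hypothetical view name
    spark.sql("SELECT zipped_feature, get_value(zipped_feature, 0) AS first_column FROM features").show(false)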
ugmeyewa2#

In my opinion, this use case does not need a user-defined function; you can get the job done with withColumn and select statements alone. Note that the sample below builds zipped_feature as array<array<string>>; with the asker's actual array-of-structs schema, indexing returns a struct instead, as shown in the sketch after the code.

    // source data
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    import spark.implicits._
    val df = Seq(
      Seq(Array("A", "1"), Array("ABC", "33")),
      Seq(Array("A", "1"), Array("ABS", "24"))).toDF("zipped_feature")
    // 1) getting the values using select statements
    val df1 = df.select($"zipped_feature"(0).as("ArrayZero"), $"zipped_feature"(1).as("ArrayOne"))
    // 2) getting the values using withColumn
    val df2 = df.withColumn("Array_Zero", $"zipped_feature"(0)).withColumn("Array_One", $"zipped_feature"(1))
    // 3) getting the value of the inner array
    val df3 = df1.select($"ArrayZero"(0).as("InnerArrayZero"))
    // 4) getting the value of the first element
    val value = df1.select($"ArrayZero"(0)).first.getString(0)
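One caveat: with the asker's array<struct<_1,_2>> column, $"zipped_feature"(0) returns a struct rather than an array, and its fields are read with getField. A minimal sketch, assuming df is the array-of-structs DataFrame built in the first answer; the output column names are made up:

    // hedged sketch against the array<struct> schema from the first answer
    df.select(
      $"zipped_feature"(0).getField("_1").as("first_name"),  // hypothetical names
      $"zipped_feature"(0).getField("_2").as("first_value")
    ).show(false)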

tp5buhyn3#

From the error message, the column zipped_feature is of array-of-structs type, not array-of-arrays type. You don't need a UDF to access an array element by index; you can use one of the following options (see the sketch after this list for the indexing conventions):

    col("zipped_feature")(idx)             // opt 1
    col("zipped_feature").getItem(idx)     // opt 2
    element_at(col("zipped_feature"), idx) // opt 3
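Note that the indexing conventions differ: the first two options are 0-based, while element_at is 1-based (and accepts negative indices counting from the end). A small sketch, again assuming the sample df from the first answer:

    import org.apache.spark.sql.functions.{col, element_at}
    // 0-based: the first element of the array (a struct here)
    df.select(col("zipped_feature")(0).as("first")).show(false)
    // 1-based: element_at(..., 1) returns the same first element
    df.select(element_at(col("zipped_feature"), 1).as("first")).show(false)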

To convert the array of structs into an array of arrays, you can use the transform function:

    val df1 = df.withColumn(
      "zipped_feature",
      expr("transform(zipped_feature, x -> array(x._1, x._2))")
    ).select(
      col("zipped_feature")(0).as("idx0"),
      col("zipped_feature")(1).as("idx1")
    )
    df1.show
    //+------+---------+
    //|  idx0|     idx1|
    //+------+---------+
    //|[A, 1]|[ABC, 33]|
    //|[A, 1]|[ABS, 24]|
    //|[B, 2]|[ABE, 17]|
    //|[C, 3]|[ABC, 33]|
    //+------+---------+
    df1.printSchema
    //root
    // |-- idx0: array (nullable = true)
    // |    |-- element: string (containsNull = true)
    // |-- idx1: array (nullable = true)
    // |    |-- element: string (containsNull = true)

Or directly, without transforming the array:

    val df1 = df.select(
      expr("array(zipped_feature[0]._1, zipped_feature[0]._2)").as("idx0"),
      expr("array(zipped_feature[1]._1, zipped_feature[1]._2)").as("idx1")
    )
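The expr form above also has an equivalent in the DataFrame API. A hedged sketch, not part of the original answer, built with functions.array and getField:

    import org.apache.spark.sql.functions.{array, col}
    // same result built with column expressions instead of a SQL expr string
    val df1b = df.select(
      array(col("zipped_feature")(0).getField("_1"), col("zipped_feature")(0).getField("_2")).as("idx0"),
      array(col("zipped_feature")(1).getField("_1"), col("zipped_feature")(1).getField("_2")).as("idx1")
    )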
hmtdttj44#

You can try the Dataset API's map method:

    def getValue(zippedFeature: Seq[(String, String)], index: Int): Seq[String] = {
      zippedFeature(index).productIterator.map(_.toString).toSeq
    }
    df.as[Seq[(String, String)]].map(x => (x, getValue(x, 0))).show
    /*
    +-------------------+------+
    |                 _1|    _2|
    +-------------------+------+
    |[[A, 1], [ABC, 33]]|[A, 1]|
    |[[A, 1], [ABS, 24]]|[A, 1]|
    |[[B, 2], [ABE, 17]]|[B, 2]|
    |[[C, 3], [ABC, 33]]|[C, 3]|
    +-------------------+------+
    */
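The result columns take the default tuple names _1 and _2; a small hedged follow-up renaming them with toDF (the new column names are an assumption):

    df.as[Seq[(String, String)]]
      .map(x => (x, getValue(x, 0)))
      .toDF("zipped_feature", "first_column") // hypothetical names
      .show(false)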
