scala - return the largest string within each group

zpf6vheq · posted 2021-07-09 in Spark

Dataset:

+---+--------+
|age|    name|
+---+--------+
| 33|    Will|
| 26|Jean-Luc|
| 55|    Hugh|
| 40|  Deanna|
| 68|   Quark|
| 59|  Weyoun|
| 37|  Gowron|
| 54|    Will|
| 38|  Jadzia|
| 27|    Hugh|
+---+--------+

This is my attempt, but it only returns the size of the largest string, not the largest string itself:

AgeName.groupBy("age")
      .agg(max(length(AgeName("name")))).show()
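For reference, one common way to keep the string itself while still comparing by length is the struct-ordering trick: Spark compares structs field by field, so taking the max of (length, name) carries the winning name along. A minimal sketch on the same AgeName DataFrame:

import org.apache.spark.sql.functions.{col, length, max, struct}

// Max over a (len, name) struct orders by len first, so the longest
// name per age survives; then unpack the name from the struct.
AgeName.groupBy("age")
  .agg(max(struct(length(col("name")).as("len"), col("name").as("name"))).as("m"))
  .select(col("age"), col("m.name").as("name"))
  .show()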

g6baxovj #1

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, length, max}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object BasicDatasetTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("BasicDatasetTest")
      .getOrCreate()

    val pairs = List(
      (33, "Will"),
      (26, "Jean-Luc"),
      (55, "Hugh"),
      (26, "Deanna"),
      (26, "Quark"),
      (55, "Weyoun"),
      (33, "Gowron"),
      (55, "Will"),
      (26, "Jadzia"),
      (27, "Hugh"))

    val schema = new StructType(Array(
      StructField("age", IntegerType, false),
      StructField("name", StringType, false)))

    val dataRDD = spark.sparkContext.parallelize(pairs).map(record => Row(record._1, record._2))

    val dataset = spark.createDataFrame(dataRDD, schema)

    // Length of each name, grouped by (age, name) so the name column is kept.
    val ageNameGroup = dataset.groupBy("age", "name")
      .agg(max(length(col("name"))))
      .withColumnRenamed("max(length(name))", "length")

    ageNameGroup.printSchema()

    // Maximum name length within each age group.
    val ageGroup = dataset.groupBy("age")
      .agg(max(length(col("name"))))
      .withColumnRenamed("max(length(name))", "length")

    ageGroup.printSchema()

    ageGroup.createOrReplaceTempView("age_group")
    ageNameGroup.createOrReplaceTempView("age_name_group")

    // Join the per-age maximum back to the per-name lengths to recover the name.
    spark.sql("select ag.age, ang.name from age_group as ag, age_name_group as ang " +
      "where ag.age = ang.age and ag.length = ang.length")
      .show()
  }
}
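For reference, the same join can be expressed with the DataFrame API instead of temp views and SQL; a sketch against the ageGroup and ageNameGroup frames built above:

// Join the per-age max length back to the per-name lengths, then keep the name.
ageGroup
  .join(ageNameGroup, Seq("age", "length"))
  .select("age", "name")
  .show()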

iugsix8n #2

The usual row_number trick should work, as long as the window is specified correctly. Borrowing @leoc's sample data:

import org.apache.spark.sql.functions.expr
import spark.implicits._ // for toDF; assumes an active SparkSession named spark

val df = Seq(
  (35, "John"),
  (22, "Jennifer"),
  (22, "Alexander"),
  (35, "Michelle"),
  (22, "Celia")
).toDF("age", "name")

val df2 = df.withColumn(
    "rownum", 
    expr("row_number() over (partition by age order by length(name) desc)")
).filter("rownum = 1").drop("rownum")

df2.show
+---+---------+
|age|     name|
+---+---------+
| 22|Alexander|
| 35| Michelle|
+---+---------+
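The same window can also be built with the typed Window API rather than a SQL expression string; a minimal sketch:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, length, row_number}

// Rank names within each age by descending length and keep the first row.
val w = Window.partitionBy("age").orderBy(length(col("name")).desc)
df.withColumn("rownum", row_number().over(w))
  .filter(col("rownum") === 1)
  .drop("rownum")
  .show()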

chy5wohz #3

Here's one approach using the Spark higher-order function aggregate, as shown below:

import org.apache.spark.sql.functions.{collect_list, expr}
import spark.implicits._ // for toDF; assumes an active SparkSession named spark

val df = Seq(
  (35, "John"),
  (22, "Jennifer"),
  (22, "Alexander"),
  (35, "Michelle"),
  (22, "Celia")
).toDF("age", "name")

df.
  groupBy("age").agg(collect_list("name").as("names")).
  withColumn(
    "longest_name",
    expr("aggregate(names, '', (acc, x) -> case when length(acc) < length(x) then x else acc end)")
  ).
  show(false)
// +---+----------------------------+------------+
// |age|names                       |longest_name|
// +---+----------------------------+------------+
// |22 |[Jennifer, Alexander, Celia]|Alexander   |
// |35 |[John, Michelle]            |Michelle    |
// +---+----------------------------+------------+

Note that higher-order functions are only available in Spark 2.4+.
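As an aside, on Spark 3.0+ the built-in max_by aggregate collapses this to a one-liner; a sketch on the same df:

// max_by(name, length(name)) returns the name tied to the maximum length,
// with no intermediate collect_list needed (requires Spark 3.0+).
df.groupBy("age")
  .agg(expr("max_by(name, length(name))").as("longest_name"))
  .show(false)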
