scala 与SparkMap列中的最大值对应的键

u1ehiz5o  于 2023-01-05  发布在  Scala
关注(0)|答案(2)|浏览(185)

如果我有一个从string到double的spark map列,有没有简单的方法用对应于最大值的键生成一个新列?
我可以使用如下所示的集合函数来实现它:

import org.apache.spark.sql.functions._

val mockedDf = Seq(1, 2, 3)
  .toDF("id")
  .withColumn("optimized_probabilities_map", typedLit(Map("foo"->0.34333337, "bar"->0.23)))
val df = mockedDf
  .withColumn("optimizer_probabilities", map_values($"optimized_probabilities_map"))
  .withColumn("max_probability", array_max($"optimizer_probabilities"))
  .withColumn("max_position", array_position($"optimizer_probabilities", $"max_probability"))
  .withColumn("optimizer_ruler_names", map_keys($"optimized_probabilities_map"))
  .withColumn("optimizer_ruler_name", $"optimizer_ruler_names"( $"max_position"))

然而,这个解决方案不必要的长,也不是很有效。还有一个可能的精度问题,因为我在使用array_position时比较双精度数。我想知道是否有一个更好的方法来做这件事,而不使用UDF,也许使用表达式字符串。

6za6bjd0

6za6bjd01#

既然你可以使用Spark 2.4+,一种方法是使用Spark-SQL内置函数aggregate,在这里我们迭代所有map_key,然后将对应的map_values与缓冲值acc.val进行比较,然后相应地更新acc.name

mockedDf.withColumn("optimizer_ruler_name", expr("""
    aggregate(
      map_keys(optimized_probabilities_map), 
      (string(NULL) as name, double(NULL) as val),
      (acc, y) ->
        IF(acc.val is NULL OR acc.val < optimized_probabilities_map[y]
        , (y as name, optimized_probabilities_map[y] as val)
        , acc
        ),
      acc -> acc.name
    )
""")).show(false)
+---+--------------------------------+--------------------+
|id |optimized_probabilities_map     |optimizer_ruler_name|
+---+--------------------------------+--------------------+
|1  |[foo -> 0.34333337, bar -> 0.23]|foo                 |
|2  |[foo -> 0.34333337, bar -> 0.23]|foo                 |
|3  |[foo -> 0.34333337, bar -> 0.23]|foo                 |
+---+--------------------------------+--------------------+
xdnvmnnf

xdnvmnnf2#

另一个解决方案是分解Map列,然后使用Window函数获取最大值,如下所示:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"id")

val df = mockedDf.select($"id", $"optimized_probabilities_map", explode($"optimized_probabilities_map"))
                 .withColumn("max_value", max($"value").over(w))
                 .where($"max_value" === $"value")
                 .drop("value", "max_value")

相关问题