如何在窗口操作后收集数据?groupby搞乱了顺序,我得到了错误的结果

sd2nnvve  于 2021-05-27  发布在  Spark
关注(0)|答案(2)|浏览(394)

我有这个密码

groupedDF.show()
val window =
      Window.partitionBy($"app_id", $"country_code").orderBy($"rate".desc)

    val windowResult = groupedDF
      .transform(calculateRankOverWindow(window))
      .limit(topN) //change this to .where("rank<=topN") and it works.

    windowResult.show()
    val finalResult = windowResult
      .groupBy("app_id", "country_code")
      .agg(collect_list("advertiser_id").as("recommended_advertiser_ids"))

    finalResult.show()

仅一个应用程序id和国家/地区的示例输出。因为有许多不同的应用程序ID和国家。
地面DF

+------+------------+-------------+-----------------+
|app_id|country_code|advertiser_id|             rate|
+------+------------+-------------+-----------------+
|    32|          UK|            9|              8.0|
|    32|          UK|            5|              5.5|
|    32|          UK|            4|              5.5|
|    32|          UK|            6|              6.1|
|    32|          UK|            3|              5.5|
|    32|          UK|            2|              2.0|
|    32|          UK|            1|6.099999999999999|
+------+------------+-------------+-----------------+

窗口结果

+------+------------+-------------+-----------------+----+
|app_id|country_code|advertiser_id|             rate|rank|
+------+------------+-------------+-----------------+----+
|    32|          UK|            9|              8.0|   1|
|    32|          UK|            6|              6.1|   2|
|    32|          UK|            1|6.099999999999999|   3|
|    32|          UK|            5|              5.5|   4|
|    32|          UK|            4|              5.5|   4|
+------+------------+-------------+-----------------+----+

所需输出格式每个应用程序id和国家代码的广告客户id列表。
不正确(我得到的)

+------+------------+--------------------------+
|app_id|country_code|recommended_advertiser_ids|
+------+------------+--------------------------+
|    32|          UK|           [9, 5, 4, 6, 3]|
+------+------------+--------------------------+

正确(我想要的)

+------+------------+--------------------------+
|app_id|country_code|recommended_advertiser_ids|
+------+------------+--------------------------+
|    32|          UK|           [9, 6, 1, 5, 4]|
+------+------------+--------------------------+

但当我执行groupby和collect时,groupby会扰乱顺序或之前的窗口操作,因此我会收集列表[9,5,4,6,3],而不是[9,6,1,5,4]。
我该怎么做?
如果我这样做了

windowResult
      .withColumn(
        "recommended_advertiser_ids",
        collect_list("advertiser_id").over(window)
      )
      .show()

它给

+------+------------+-------------+-----------------+----+--------------------------+
|app_id|country_code|advertiser_id|             rate|rank|recommended_advertiser_ids|
+------+------------+-------------+-----------------+----+--------------------------+
|    32|          UK|            9|              8.0|   1|                       [9]|
|    32|          UK|            6|              6.1|   2|                    [9, 6]|
|    32|          UK|            1|6.099999999999999|   3|                 [9, 6, 1]|
|    32|          UK|            5|              5.5|   4|           [9, 6, 1, 5, 4]|
|    32|          UK|            4|              5.5|   4|           [9, 6, 1, 5, 4]|
+------+------------+-------------+-----------------+----+--------------------------+

但我只想要每组的最后一个(app\u id,country\u code)。
更新:我修复了它恢复到我原来的代码。i、 e.切换回 df.where(rank<=n) instead of limit(n) 但我切换的原因是,如果列具有相同的值,秩<=n可以给出n个以上的结果。所以问题是如何从排名中选出前n名(每组前n名)?

nom7f22z

nom7f22z1#

原来问题的答案是。只要替换一下 limit(n) with df.where("rank<=${n}") . 问题是我是怎么计算军衔的。我使用了window rank()函数,它可以为相同的值提供相同的秩。所以我在答案中得到了n个以上的值。
更新后的问题的答案是:如何获得n个值是使用另一个窗口函数row\u number(),它根据窗口增加行数。
变化如此之大 limit() to where and rank() to row_number() solved both.

axkjgtzd

axkjgtzd2#

使用内置spark monotonically_increasing_id 函数,然后通过对 monotonically_increasing_id 列以保留顺序。 Example: ```
df.show()
//+------+------------+-------------+
//|app_id|country_code|advertiser_id|
//+------+------------+-------------+
//| 32| UK| 9|
//| 32| UK| 6|
//| 32| UK| 1|
//| 32| UK| 5|
//| 32| UK| 4|
//+------+------------+-------------+
import org.apache.spark.sql.functions._

df.withColumn("mid",monotonically_increasing_id()).
groupBy("app_id","country_code"). agg(sort_array(collect_list(struct(col("mid"),col("advertiser_id")))).alias("sor")).
selectExpr("app_id","country_code","""transform(sor,x -> x.advertiser_id) as recommended_advertiser_ids""").
show()

//+------+------------+--------------------------+
//|app_id|country_code|recommended_advertiser_ids|
//+------+------------+--------------------------+
//| 32| UK| [9, 6, 1, 5, 4]|
//+------+------------+--------------------------+

另一种方法是使用 `window function` 只过滤掉 `max array size` .

import org.apache.spark.sql.expressions._
val w=Window.orderBy(monotonically_increasing_id())
val df2=df.withColumn("cl",collect_list(col("advertiser_id")).over(w))
val max_size=df2.selectExpr("max(size(cl))").collect()(0)(0).toString.toInt

val cols=Seq("app_id","country_code","cl")
df2.filter(size(col("cl"))===max_size).
select(cols.head,cols.tail:_*).
show()
//+------+------------+---------------+
//|app_id|country_code| cl|
//+------+------------+---------------+
//| 32| UK|[9, 6, 1, 5, 4]|
//+------+------------+---------------+

相关问题