dataframe—在sparkscala中,按分组后,在df中跨用户计数元素

lyr7nygr  于 2021-07-09  发布在  Spark
关注(0)|答案(1)|浏览(375)

我有这个数据框:

|User     |country|
|   Ron   |  italy| 
|   Tom   |  japan|
|   Lin   |  spain|
|   Tom   |  china|
|   Tom   |  china|
|   Lin   |  japan|
|   Tom   |  china|
|   Lin   |  japan|

我想计算每个用户的国家总数。例如,对于上面的df,我将得到:

[Ron ->  [italy ->1], Tom -> [Japan -> 1, china -> 3], Lin -> [Spain -> 1, Japan ->2]]

我从

val groupedbyDf = df.groupBy("User")

但我不知道如何继续。。agg()?

vshtjzan

vshtjzan1#

您需要在分组后使用相关Map功能创建Map:

val df2 = df.groupBy("User", "country")
  .count()
  .groupBy("User")
  .agg(map(
      col("User"), 
      map_from_entries(collect_list(struct(col("country"), col("count"))))
      ).as("result")
  )
  .select("result")

df2.show(false)
+---------------------------------+
|result                           |
+---------------------------------+
|[Tom -> [china -> 3, japan -> 1]]|
|[Lin -> [spain -> 1, japan -> 2]]|
|[Ron -> [italy -> 1]]            |
+---------------------------------+

如果要将它们全部放在一行中,可以再进行一次聚合:

val df2 = df.groupBy("User", "country")
  .count()
  .groupBy("user")
  .agg(map_from_entries(collect_list(struct(col("country"), col("count")))).as("result"))
  .agg(map_from_entries(collect_list(struct(col("user"), col("result")))).as("result_all"))

df2.show(false)
+---------------------------------------------------------------------------------------+
|result_all                                                                             |
+---------------------------------------------------------------------------------------+
|[Tom -> [china -> 3, japan -> 1], Lin -> [spain -> 1, japan -> 2], Ron -> [italy -> 1]]|
+---------------------------------------------------------------------------------------+

相关问题