How do I aggregate values inside an array in PySpark?

lmyy7pcs · asked on 2021-05-29 · Spark

Spark version 3.0.
I have a DataFrame like this:

+-------------------------------------------------+
|named_val                                        |
+-------------------------------------------------+
|[[Alex, 1], [is, 1], [a, 1], [good, 1], [boy, 1]]|
|[[Bob, 1], [Bob, 1], [bad, 1], [Bob, 1]]         |
+-------------------------------------------------+

I need to build a map with the count of each unique value, as shown below.
Expected output:

+-------------------------------------------------+
|named_val                                        |
+-------------------------------------------------+
|{Alex->1, is->1, a->1, good->1, boy->1}          |
|{Bob->3, bad->1}                                 |
+-------------------------------------------------+

To reproduce, use:

df = spark.createDataFrame(
    [
        ([['Alex', 1], ['is', 1], ['a', 1], ['good', 1], ['boy', 1]],),
        ([['Bob', 1], ['Bob', 1], ['bad', 1], ['Bob', 1]],),
    ],
    ['named_val'],
)

7hiiyaii · Answer 1

This is in Scala, but the Python version is very similar (a PySpark sketch of the same steps follows the output below):

val df =  Seq(Seq(("Alex",1),("is",1),("a",1),("good",1),("boy",1)),Seq(("Bob",1),("Bob",1),("bad",1),("Bob",1))).toDF()
df.show(false)
+-------------------------------------------------+
|value                                            |
+-------------------------------------------------+
|[[Alex, 1], [is, 1], [a, 1], [good, 1], [boy, 1]]|
|[[Bob, 1], [Bob, 1], [bad, 1], [Bob, 1]]         |
+-------------------------------------------------+

df.withColumn("id",monotonicallyIncreasingId)
.select('id,explode('value))
.select('id,'col.getField("_1").as("val"))
.groupBy('id,'val).agg(count('val).as("ct"))
.select('id,map('val,'ct).as("map"))
.groupBy('id).agg(collect_list('map))
.show(false)

+---+-----------------------------------------------------------+
|id |collect_list(map)                                          |
+---+-----------------------------------------------------------+
|0  |[[is -> 1], [Alex -> 1], [boy -> 1], [a -> 1], [good -> 1]]|
|1  |[[bad -> 1], [Bob -> 3]]                                   |
+---+-----------------------------------------------------------+
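
For reference, a rough PySpark translation of the same steps (an untested sketch; it assumes the question's named_val column is an array of [word, count] arrays, so getItem(0) takes the place of getField("_1")):

from pyspark.sql import functions as F

result = (
    df.withColumn("id", F.monotonically_increasing_id())     # tag each original row
      .select("id", F.explode("named_val").alias("pair"))    # one row per [word, count] element
      .select("id", F.col("pair").getItem(0).alias("val"))   # keep only the word
      .groupBy("id", "val").agg(F.count("val").alias("ct"))  # count occurrences per row and word
      .select("id", F.create_map("val", "ct").alias("map"))  # wrap each count in a single-entry map
      .groupBy("id").agg(F.collect_list("map"))              # gather the maps back per original row
)
result.show(truncate=False)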

iswrvxsc · Answer 2

What about our old friend, the UDF? Compared to a shuffle, the ser/de cost should be low:

from pyspark.sql.functions import udf

def sum_merge(ar):
    # Sum the counts per unique key across the array of [word, count] pairs.
    d = dict()
    for i in ar:
        k, v = i[0], int(i[1])
        d[k] = d[k] + v if k in d else v
    return d

sum_merge_udf = udf(sum_merge)

df.select(sum_merge_udf("named_val").alias("named_val")).show(truncate=False)

# +----------------------------------+
# |named_val                         |
# +----------------------------------+
# |{a=1, Alex=1, is=1, boy=1, good=1}|
# |{bad=1, Bob=3}                    |
# +----------------------------------+
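
As written, udf defaults to a string return type, so the named_val shown above is a string rendering of the dict rather than a true map column. If a real map is needed downstream, the same function can be given an explicit return type; a minimal sketch, assuming string keys and integer counts:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, MapType, StringType

# Same sum_merge logic, but declared as map<string,int> so the result is a
# proper map column instead of its string representation.
typed_sum_merge_udf = udf(sum_merge, MapType(StringType(), IntegerType()))

df.select(typed_sum_merge_udf("named_val").alias("named_val")).printSchema()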
