在没有udf的情况下,如何在pyspark上将多个数组交到单个数组中

syqv5f0l  于 2021-07-13  发布在  Spark
关注(0)|答案(2)|浏览(320)

我有以下代码:

elements = spark.createDataFrame([
('g1', 'a', 1), ('g1', 'a', 2), ('g1', 'b', 1), ('g1', 'b', 3),
('g2', 'c', 1), ('g2', 'c', 3), ('g2', 'c', 2), ('g2', 'd', 4),
], ['group', 'instance', 'element'])

all_elements_per_instance = elements.groupBy("group", "instance").agg(f.collect_set('element').alias('elements'))

# +-----+--------+---------+

# |group|instance| elements|

# +-----+--------+---------+

# |   g1|       b|   [1, 3]|

# |   g1|       a|   [1, 2]|

# |   g2|       c|[1, 2, 3]|

# |   g2|       d|      [4]|

# +-----+--------+---------+

@f.udf(ArrayType(IntegerType()))
def intersect(elements):
    return list(functools.reduce(lambda x, y: set(x).intersection(set(y)), elements))

all_intersect_elements_per_group = all_elements_per_instance.groupBy("group")\
    .agg(intersect(f.collect_list("elements")).alias("intersection"))

# +-----+------------+

# |group|intersection|

# +-----+------------+

# |   g1|         [1]|

# |   g2|          []|

# +-----+------------+

有没有办法避免使用udf(因为它很昂贵),并以某种方式使用 f.array_intersect 或者类似于聚合函数的函数?

7lrncoxx

7lrncoxx1#

你可以使用高阶函数 aggregate 做一个 array_intersect 关于要素:

import pyspark.sql.functions as f
result = all_elements_per_instance.groupBy('group').agg(
    f.expr("""
        aggregate(
            collect_list(elements),
            collect_list(elements)[0],
            (acc, x) -> array_intersect(acc, x)
        ) as intersection
    """)
)

result.show()
+-----+------------+
|group|intersection|
+-----+------------+
|   g2|          []|
|   g1|         [1]|
+-----+------------+
camsedfj

camsedfj2#

如果你想找到 elements 至少由2个共享的 instances 在每个 group ,您实际上可以通过使用窗口然后使用groupby来计算每个组/元素的不同示例来简化它 group 仅收集计数大于1的元素:

from pyspark.sql import Window
from pyspark.sql import functions as F

result = elements.withColumn(
    "cnt",
    F.size(F.collect_set("instance").over(Window.partitionBy("group", "element")))
).groupBy("group").agg(
    F.collect_set(
        F.when(F.col("cnt") > 1, F.col("element"))
    ).alias('intersection')
)

result.show()

# +-----+------------+

# |group|intersection|

# +-----+------------+

# |   g2|          []|

# |   g1|         [1]|

# +-----+------------+

我曾经 collect_set + size 作为函数 countDistinct 不支持窗口。

相关问题