我想按元素对一列数组中的数组求和-这列数组应该聚合为一个数组。下面的代码给出了所需的结果[3,6,9],但它使用了一个UDF,在缩放时会导致OOM。我希望有同样的结果,但纯粹在Spark!
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType,ArrayType
def sum_arr_by_element(l):
return [sum(x) for x in zip(*l)]
my_udf = udf(sum_arr_by_element, ArrayType(IntegerType()))
data = [
([1, 2, 3],),
([1, 2, 3],),
([1, 2, 3],),
]
df = spark.createDataFrame(data, ["array_column"])
df.agg(F.collect_list("array_column").alias('all_lists')).withColumn('summed',my_udf("all_lists")).select('summed').display()
3条答案
按热度按时间db2dz4w81#
试试这个:
ovfsdjhp2#
这里有一种方法,首先按行聚合数组的元素,然后按元素在原始列表中的位置顺序收集元素
jpfvwuh43#
你可以在Spark中使用阵列的最大功率。
如果你有一个分组列,收集组内的所有数组,并使用
aggregate
函数。该函数将保留元素位置。这里有一个例子