pyspark 像f.sum一样的一列乘法

ubbxdtey  于 11个月前  发布在  Spark
关注(0)|答案(3)|浏览(166)

我在数据框中有一列。我需要通过将该列中的值相乘而不是将它们相加来聚合列。

ex = spark.createDataFrame([[1,2],[4,5]],['a','b'])
ex.show()
ex.agg(f.sum('a')).show()

字符串
而不是求和,我想用类似这样的语法乘以列'a':

ex.agg(f.mul('a')).show()


我想到的变通办法是

ex.agg(f.exp(f.sum(f.log('a')))).show()


然而,计算exp(sum(log))可能不够高效,
结果应该是4.什么是最有效的方法?

nlejzf6q

nlejzf6q1#

没有内置的乘法聚合。你的解决方案对我来说似乎很有效,其他解决方案需要构建自定义聚合函数。

import pyspark.sql.functions as F
ex = spark.createDataFrame([[1,2],[4,5], [6,7], [3,2], [9,8], [4,2]],['a','b'])
ex.show()

+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  4|  5|
|  6|  7|
|  3|  2|
|  9|  8|
|  4|  2|
+---+---+

# Solution 1
ex.agg(F.exp(F.sum(F.log('a')))).show()

+----------------+
|EXP(sum(LOG(a)))|
+----------------+
|          2592.0|
+----------------+

# Solution 2
from pyspark.sql.types import IntegerType

def mul_list(l):
    return reduce(lambda x,y: x*y, l)  # In Python 3, use `from functools import reduce`

udf_mul_list = F.udf(mul_list, IntegerType())
ex.agg(udf_mul_list(F.collect_list('a'))).show()

+-------------------------------+
|mul_list(collect_list(a, 0, 0))|
+-------------------------------+
|                           2592|
+-------------------------------+

# Solution 3
seqOp = (lambda local_result, row: local_result * row['a'] )
combOp = (lambda local_result1, local_result2: local_result1 * local_result2)
ex_rdd = ex.rdd
ex_rdd.aggregate( 1, seqOp, combOp)

Out[4]: 2592

字符串
现在让我们比较性能:

import random
ex = spark.createDataFrame([[random.randint(1, 10), 3] for i in range(10000)],['a','b'])

%%timeit
ex.agg(F.exp(F.sum(F.log('a')))).count()

10 loops, best of 3: 84.9 ms per loop

%%timeit
ex.agg(udf_mul_list(F.collect_list('a'))).count()

10 loops, best of 3: 78.8 ms per loop

%%timeit
ex_rdd = ex.rdd
ex_rdd.aggregate( 1, seqOp, combOp)

10 loops, best of 3: 94.3 ms per loop


在本地的一个分区上,性能看起来差不多。请在多个分区上尝试使用更大的内存。
为提高解决方案2和3的性能:构建a custom aggregation function in Scalawrap it in Python

q5lcpyga

q5lcpyga2#

当我看到python Spark API中的限制时,我总是看一下高阶函数,因为它们给予您访问可能尚未集成到PySpark的功能。此外,当您使用优化的原生Spark操作时,它们通常会对UDF提供给予更好的性能。您可以在这里阅读更多关于高阶函数的信息:https://medium.com/@danniesim/faster-and-more-concise-than-udf-spark-functions-and-higher-order-functions-with-pyspark-31d31de5fed8
对于你的问题,你可以使用f.aggegate,你可以在Spark文档中找到一些例子:https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.functions.aggregate.html#pyspark.sql.functions.aggregate。这里可以参考如何通过相乘来聚合值:

ex.agg(f.aggregate('a', f.lit(1.0), lambda acc, x: acc * x))

字符串
编辑:f.aggregate可从PySpark 3.1.0获得,如果您有以前的版本,您可以执行以下操作(同样,另一个高阶函数可以使用Spark SQL API中的aggregate:https://spark.apache.org/docs/latest/api/sql/#aggregate):

ex
.agg(f.collect_list('a').alias('a'))
.withColumn('a', f.expr("aggregate(a, CAST(1.0 AS DOUBLE), (acc, x) -> acc * x, acc -> acc)"))


像这样,你只使用了原生的spark API,但不用说,对于只在一个组上相乘的值来说,这看起来太复杂了。

相关问题