pySpark Dataframe 中的累积乘积

50pmv0ei  于 2023-10-15  发布在  Spark
关注(0)|答案(4)|浏览(112)

下面是Spark DataFrame:

+---+---+
|  a|  b|
+---+---+     
|  1|  1|  
|  1|  2|  
|  1|  3|
|  1|  4|
+---+---+

我想创建另一个名为"c"的列,其中包含“B”与“a”的乘积。生成的DataFrame应该如下所示:

+---+---+---+
|  a|  b|  c|
+---+---+---+     
|  1|  1|  1|
|  1|  2|  2|
|  1|  3|  6|
|  1|  4| 24|
+---+---+---+

如何做到这一点?

pjngdqdw

pjngdqdw1#

下面是一种不使用用户定义函数的替代方法

df = spark.createDataFrame([(1, 1), (1, 2), (1, 3), (1, 4), (1, 5)], ['a', 'b'])
wind = Window.partitionBy("a").rangeBetween(Window.unboundedPreceding, Window.currentRow).orderBy("b")
df2 = df.withColumn("foo", collect_list("b").over(wind))
df2.withColumn("foo2", expr("aggregate(foo, cast(1 as bigint), (acc, x) -> acc * x)")).show()

+---+---+---------------+----+
|  a|  b|            foo|foo2|
+---+---+---------------+----+
|  1|  1|            [1]|   1|
|  1|  2|         [1, 2]|   2|
|  1|  3|      [1, 2, 3]|   6|
|  1|  4|   [1, 2, 3, 4]|  24|
|  1|  5|[1, 2, 3, 4, 5]| 120|
+---+---+---------------+----+

如果你真的不在乎精度,你可以建立一个更短的版本,

import pyspark.sql.functions as psf

df.withColumn("foo", psf.exp(psf.sum(psf.log("b")).over(wind))).show()
+---+---+------------------+
|  a|  b|               foo|
+---+---+------------------+
|  1|  1|               1.0|
|  1|  2|               2.0|
|  1|  3|               6.0|
|  1|  4|23.999999999999993|
|  1|  5|119.99999999999997|
+---+---+------------------
p8ekf7hl

p8ekf7hl2#

你必须设置一个命令栏。在你的例子中,我用了“B”栏

from pyspark.sql import functions as F, Window, types
from functools import reduce
from operator import mul

df = spark.createDataFrame([(1, 1), (1, 2), (1, 3), (1, 4), (1, 5)], ['a', 'b'])

order_column = 'b'

window = Window.orderBy(order_column)

expr = F.col('a') * F.col('b')

mul_udf = F.udf(lambda x: reduce(mul, x), types.IntegerType())

df = df.withColumn('c', mul_udf(F.collect_list(expr).over(window)))

df.show()

+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  1|  1|
|  1|  2|  2|
|  1|  3|  6|
|  1|  4| 24|
|  1|  5|120|
+---+---+---+
pwuypxnk

pwuypxnk3#

from pyspark.sql import functions as F, Window
df = spark.createDataFrame([(1, 1), (1, 2), (1, 3), (1, 4), (1, 5)], ['a', 'b'])

window = Window.orderBy('b').rowsBetween(Window.unboundedPreceding,Window.currentRow)

df = df.withColumn('c', F.product(F.col('b')).over(window))

df.show()

+---+---+-----+
|  a|  b|    c|
+---+---+-----+
|  1|  1|  1.0|
|  1|  2|  2.0|
|  1|  3|  6.0|
|  1|  4| 24.0|
|  1|  5|120.0|
+---+---+-----+
nqwrtyyt

nqwrtyyt4#

你的回答与此类似。

import pandas as pd
df = pd.DataFrame({'v':[1,2,3,4,5,6]})
df['prod'] = df.v.cumprod()
   v   prod
0  1     1
1  2     2
2  3     6
3  4    24
4  5   120
5  6   720

相关问题