pyspark 根据Spark中前面的行计算列的乘积

sczxawaw  于 2022-11-01  发布在  Spark
关注(0)|答案(1)|浏览(122)

我有一个Spark Dataframe ,我想根据前面行中的2列计算后面行的值。我知道如何只计算1行(使用lag()函数),但我不知道如何将前面行中的这些值传递到后面的几行。

id | month | value | monthly_increment
1  | 01    | 100   | 2
1  | 02    | 200   | 3
1  | 03    | 600   | 4
1  | 04    | 2400  | 2

正如您所看到的,列“value”的值乘以“monthly_increment”,并且它一直影响该特定“id”的所有后续值。
如何使用PySpark实现这一点?

vsaztqbk

vsaztqbk1#

在询问Spark问题时,提供示例输入 Dataframe 是非常重要的,你没有提供,所以我假设你的输入 Dataframe 看起来像这样:

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [('1', '01',  100, 2),
     ('1', '02', None, 3),
     ('1', '03', None, 4),
     ('1', '04', None, 2)],
    ['id', 'month', 'value', 'monthly_increment'])

Spark3.2+**
您可以使用productlagfirst窗口函数的组合来填充缺少的列“值”值:

w = W.partitionBy('id').orderBy('month')
factor = F.product(F.lag('monthly_increment').over(w)).over(w)
df = df.withColumn('value', F.coalesce(F.first('value').over(w) * factor, 'value'))

df.show()

# +---+-----+------+-----------------+

# | id|month| value|monthly_increment|

# +---+-----+------+-----------------+

# |  1|   01| 100.0|                2|

# |  1|   02| 200.0|                3|

# |  1|   03| 600.0|                4|

# |  1|   04|2400.0|                2|

# +---+-----+------+-----------------+

相关问题