pyspark:基于条件的窗口求和

mm5n2pyu  于 2021-07-12  发布在  Spark
关注(0)|答案(1)|浏览(458)

考虑一下简单的Dataframe:

from pyspark import SparkContext
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType
spark = SparkSession.builder.appName('Trial').getOrCreate()

simpleData = (("2000-04-17", "144", 1), \
    ("2000-07-06", "015", 1),  \
    ("2001-01-23", "015", -1),   \
    ("2001-01-18", "144", -1),  \
    ("2001-04-17", "198", 1),    \
    ("2001-04-18", "036", -1),  \
    ("2001-04-19", "012", -1),    \
    ("2001-04-19", "188", 1), \
    ("2001-04-25", "188", 1),\
    ("2001-04-27", "015", 1) \
  )

columns= ["dates", "id", "eps"]
df = spark.createDataFrame(data = simpleData, schema = columns)
df.printSchema()
df.show(truncate=False)

输出:

root
 |-- dates: string (nullable = true)
 |-- id: string (nullable = true)
 |-- eps: long (nullable = true)

+----------+---+---+
|dates     |id |eps|
+----------+---+---+
|2000-04-17|144|1  |
|2000-07-06|015|1  |
|2001-01-23|015|-1 |
|2001-01-18|144|-1 |
|2001-04-17|198|1  |
|2001-04-18|036|-1 |
|2001-04-19|012|-1 |
|2001-04-19|188|1  |
|2001-04-25|188|1  |
|2001-04-27|015|1  |
+----------+---+---+

我想把 eps 滚动窗口上的列,只保留 id 列。例如,定义一个5行的窗口,假设我们在2001-04-17,我只想求最后一行的和 eps 每个给定唯一id的值。在5行中,我们只有3个不同的id,因此总和必须是3个元素:-id 144(第四行)为1,-id 015(第三行)为1,id 198(第五行)为1,总计为-1。
在我的脑海里,在滚动的窗口里,我应该做一些像 F.sum(groupBy('id').agg(F.last('eps'))) 这当然不可能在滚动窗口中实现。
我使用自定义项获得了所需的结果。

@pandas_udf(IntegerType(), PandasUDFType.GROUPEDAGG)
def fun_sum(id, eps):
    df = pd.DataFrame()
    df['id'] = id
    df['eps'] = eps
    value = df.groupby('id').last().sum()
    return value

然后:

w = Window.orderBy('dates').rowsBetween(-5,0)
df = df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w))

问题是,我的数据集包含超过800万行,使用这个udf执行此任务需要大约2个小时。
我想知道是否有一种方法可以通过内置的pyspark函数实现相同的结果,避免使用udf,或者至少有一种方法可以提高我的udf的性能。
为完整起见,所需输出应为:

+----------+---+---+----+
|dates     |id |eps|sum |
+----------+---+---+----+
|2000-04-17|144|1  |1   |
|2000-07-06|015|1  |2   |
|2001-01-23|015|-1 |0   |
|2001-01-18|144|-1 |-2  |
|2001-04-17|198|1  |-1  |
|2001-04-18|036|-1 |-2  |
|2001-04-19|012|-1 |-3  |
|2001-04-19|188|1  |-1  |
|2001-04-25|188|1  |0   |
|2001-04-27|015|1  |0   |
+----------+---+---+----+

编辑:使用 .rangeBetween() Windows。

vwkv1x7d

vwkv1x7d1#

如果你还没弄明白,这里有一个方法。
假设 df 定义和初始化的方式与您在问题中定义和初始化的方式相同。
导入所需的函数和类:

from pyspark.sql.functions import row_number, col
from pyspark.sql.window import Window

创建必要的 WindowSpec :

window_spec = (
    Window
    # Partition by 'id'.
    .partitionBy(df.id)
    # Order by 'dates', latest dates first.
    .orderBy(df.dates.desc())
)

创建 DataFrame 使用分区数据:

partitioned_df = (
    df
    # Use the window function 'row_number()' to populate a new column
    # containing a sequential number starting at 1 within a window partition.
    .withColumn('row', row_number().over(window_spec))
    # Only select the first entry in each partition (i.e. the latest date).
    .where(col('row') == 1)
)

以防你想再次检查数据:

partitioned_df.show()

# +----------+---+---+---+

# |     dates| id|eps|row|

# +----------+---+---+---+

# |2001-04-19|012| -1|  1|

# |2001-04-25|188|  1|  1|

# |2001-04-27|015|  1|  1|

# |2001-04-17|198|  1|  1|

# |2001-01-18|144| -1|  1|

# |2001-04-18|036| -1|  1|

# +----------+---+---+---+

分组并聚合数据:

sum_rows = (
    partitioned_df
    # Aggragate data.
    .groupBy()
    # Sum all rows in 'eps' column.
    .sum('eps')
    # Get all records as a list of Rows.
    .collect()
)

得到结果:

print(f"sum eps: {sum_rows[0][0]})

# sum eps: 0

相关问题