Pyspark对具有连续编号的行进行分组(具有重复项)[重复项]

xytpbqjk  于 2023-05-21  发布在  Spark
关注(0)|答案(1)|浏览(155)

此问题已在此处有答案

Create the array of integer with consecutive number in PySpark(1个答案)
3天前关闭。
我有一个数据,其中有一个连续的时间槽的客户到达

df = spark.createDataFrame(
    [(0, 'A'),
     (1, 'B'),
     (1, 'C'),
     (5, 'D'),
     (8, 'A'),
     (9, 'F'),
     (20, 'T'),
     (20, 'S'),
     (21, 'C')],
    ['time_slot', 'customer'])
+--------+--------+
|time_slot|customer|
+--------+--------+
|       0|       A|
|       1|       B|
|       1|       C|
|       5|       D|
|       8|       A|
|       9|       F|
|      20|       T|
|      20|       S|
|      21|       C|
+--------+--------+

我需要按顺序时隙(包括重复时隙)对客户进行分组,以便获得:

+--------------------+---------------------------------------------+
|       grouped_slots|                            grouped_customers|
+--------------------+---------------------------------------------+
|              [0, 1]|                              ['A', 'B', 'C']|
|                 [5]|                                        ['D']|
|              [8, 9]|                                   ['A', 'F']|
|            [20, 21]|                              ['T', 'S', 'C']|
+--------------------+---------------------------------------------+

我曾尝试使用滞后功能来查看prev记录,但不知道如何根据该分组

window = W.orderBy("time_slot")
df = df.withColumn("prev_time_slot", F.lag(F.col('time_slot'), 1).over(window))
+---------+--------+--------------+
|time_slot|customer|prev_time_slot|
+---------+--------+--------------+
|        0|       A|          null|
|        1|       B|             0|
|        1|       C|             1|
|        5|       D|             1|
|        8|       A|             5|
|        9|       F|             8|
|       20|       T|             9|
|       20|       S|            20|
|       21|       C|            20|
+---------+--------+--------------+
cnwbcb6i

cnwbcb6i1#

该代码首先计算包含前一行的时隙的滞后列。然后,每当当前时隙和前一个时隙之间的差大于1时,它就标记新序列的开始。它使用累积和为每个序列创建唯一的组ID。最后,它按照这些组ID对DataFrame进行分组,将时隙和客户聚合到列表中:

window = Window.orderBy("time_slot")

df = df.withColumn("prev_time_slot", F.lag(F.col('time_slot')).over(window))

df = df.withColumn("isNewSequence", 
                   (F.col("time_slot") - F.col("prev_time_slot") > 1).cast("int"))

df = df.withColumn("groupId", F.sum("isNewSequence").over(window))

df_grouped = df.groupBy("groupId").agg(F.collect_set("time_slot").alias("grouped_slots"), 
                                        F.collect_set("customer").alias("grouped_customers"))

df_grouped.show(truncate=False)

相关问题