pysparkDataframe条件(按窗口/延迟)

9udxz4iz  于 2021-05-16  发布在  Spark
关注(0)|答案(1)|浏览(566)

我对spark和sparkDataframe还不熟悉。我有一个sparkDataframe,比如:


# For sake of simplicity only one user (uid) is shown, but there are multiple users

+-------------------+-----+-------+
|start_date         |uid  |count  |
+-------------------+-----+-------+
|2020-11-26 08:30:22|user1|  4    |
|2020-11-26 10:00:00|user1|  3    |
|2020-11-22 08:37:18|user1|  3    |
|2020-11-22 13:32:30|user1|  2    |
|2020-11-20 16:04:04|user1|  2    |
|2020-11-16 12:04:04|user1|  1    |

如果用户在过去至少有count>=x个事件,我想创建一个新的布尔列,其中的值为true/false,并用true标记这些事件。例如,对于x=3,我希望得到:

+-------------------+-----+-------+--------------+
|start_date         |uid  |count  | marked_event |
+-------------------+-----+-------+--------------+
|2020-11-26 08:30:22|user1|  4    |  True        |
|2020-11-26 10:00:00|user1|  3    |  True        |
|2020-11-22 08:37:18|user1|  3    |  True        |
|2020-11-22 13:32:30|user1|  2    |  True        |
|2020-11-20 16:04:04|user1|  2    |  True        |
|2020-11-16 12:04:04|user1|  1    |  False       |

也就是说,对于每个计数>=3,我需要将该事件标记为true,以及前面的3个事件。只有user1的最后一个事件是false,因为我在start\u date=2020-11-22 08:37:18的事件之前(包括)标记了3个事件。
有什么办法吗?我的直觉是以某种方式使用窗口/滞后来实现这一点,但我是新的Spark,不知道如何做到这一点。。。
谢谢!
编辑:
我在@mck的解决方案上使用了一个变体,并修复了一个小错误:原来的解决方案有:

F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing))

条件,不管是否满足“count”的条件,它都会在“begin”之后标记所有事件。相反,我更改了解决方案,以便窗口只标记“begin”之前发生的事件:

event = (f.max(f.col('begin')).over(w.rowsBetween(-2, 0))).\ 
          alias('event_post_only') 

# the number of events to mark is 3 from 'begin',

# including the event itself, so that's -2.

df_marked_events = df_marked_events.select('*', event)

然后为所有在“event\u post\u only”中为true或在“event\u post\u only”中为true的事件标记true

df_marked_events = df_marked_events.withColumn('event', (col('count') >= 3) \
                       | (col('event_post_only')))

这避免了将上游的所有内容都标记为true以“begin”==true

b09cbbtk

b09cbbtk1#

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('uid').orderBy(F.col('count').desc(), F.col('start_date'))

# find the beginning point of >= 3 events

begin = (
    (F.col('count') >= 3) &
    (F.lead(F.col('count')).over(w) < 3)
).alias('begin')
df = df.select('*', begin)

# Mark as event if the event is in any rows after begin, or two rows before begin

event = (
    F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing)) | 
    F.max(F.col('begin')).over(w.rowsBetween(-2,0))
).alias('event')
df = df.select('*', event)

df.show()
+-------------------+-----+-----+-----+-----+
|         start_date|  uid|count|begin|event|
+-------------------+-----+-----+-----+-----+
|2020-11-26 08:30:22|user1|  4.0|false| true|
|2020-11-22 08:37:18|user1|  3.0|false| true|
|2020-11-26 10:00:00|user1|  3.0| true| true|
|2020-11-20 16:04:04|user1|  2.0|false| true|
|2020-11-22 13:32:30|user1|  2.0|false| true|
|2020-11-16 12:04:04|user1|  1.0|false|false|
+-------------------+-----+-----+-----+-----+

相关问题