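Splitting a list of overlapping intervals into non-overlapping sub-intervals in a PySpark DataFrame, and checking whether values agree on the overlapping intervals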

Posted by js5cn81o on 2021-07-13 in Spark

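I have a PySpark DataFrame with columns start_time and end_time that define an interval for each row. It also contains a column is_duplicated, set to True if the interval is overlapped by at least one other interval and False otherwise.
There is also a column rate, and I want to know whether overlapping sub-intervals (which overlap by definition) carry different values; if they do, I want to keep the record with the most recent update in column updated_at as the ground truth.
As an intermediate step, I would like to create a column is_validated set to:
- None when the sub-interval is not overlapped;
- True when the sub-interval is overlapped by another sub-interval with a different rate value and is the most recently updated;
- False when the sub-interval is overlapped by another sub-interval with a different rate value and is not the most recently updated.
Note: the intermediate step is not mandatory; I include it only to make the explanation clearer.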
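The is_validated rule can be sketched in plain Python for the records that cover a single sub-interval. This is a minimal sketch, not code from the question: the helper `validate_group` and the `(rate, updated_at)` record layout are hypothetical, and it assumes timestamps are ISO-formatted strings, which compare chronologically. Following the question's example table, every record in an overlapped group gets True or False, even when the rates agree.

```python
from typing import List, Optional, Tuple

# Hypothetical record layout: (rate, updated_at as 'YYYY-MM-DD HH:MM:SS').
Record = Tuple[int, str]


def validate_group(records: List[Record]) -> List[Optional[bool]]:
    """Return an is_validated flag for each record covering the same sub-interval.

    None  -> only one record covers the sub-interval (no overlap)
    True  -> the record carries the most recent updated_at
    False -> another overlapping record was updated more recently
    """
    if len(records) < 2:
        return [None] * len(records)
    # ISO-formatted timestamps sort chronologically as plain strings.
    latest = max(updated_at for _, updated_at in records)
    return [updated_at == latest for _, updated_at in records]


# Sub-interval 2018-01-03..2018-01-04 is covered by two input rows.
print(validate_group([(10, "2021-02-25 00:00:00"), (20, "2021-02-20 00:00:00")]))
# → [True, False]
```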
Input:


# So this:

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20          
              Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: full overlap for (2,3) with (1,4)               
              Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'),  # OVERLAP: (3,5) and (1,4) and rate=10/20                          
              Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'),  # NO OVERLAP: hole between (5,6)                                            
              Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')]  # NO OVERLAP

df = spark.createDataFrame(input_rows)
df.show()

+-------------------+-------------------+----+-------------------+
|         start_time|           end_time|rate|         updated_at|
+-------------------+-------------------+----+-------------------+
|2018-01-01 00:00:00|2018-01-04 00:00:00|  10|2021-02-25 00:00:00|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|2021-02-20 00:00:00|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|2021-02-25 00:00:00|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|2021-02-25 00:00:00|
+-------------------+-------------------+----+-------------------+

# Will become:

tmp_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True,  is_validated=True),
            Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True,  is_validated=True),
            Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True,  is_validated=True),
            Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=True,  is_validated=False),
            Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None)
           ]
tmp_df = spark.createDataFrame(tmp_rows)
tmp_df.show()

+-------------------+-------------------+----+-------------------+-------------+------------+
|         start_time|           end_time|rate|         updated_at|is_duplicated|is_validated|
+-------------------+-------------------+----+-------------------+-------------+------------+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|2021-02-25 00:00:00|        false|        null|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  20|2021-02-20 00:00:00|         true|       false|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|2021-02-20 00:00:00|        false|        null|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|2021-02-25 00:00:00|        false|        null|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|2021-02-25 00:00:00|        false|        null|
+-------------------+-------------------+----+-------------------+-------------+------------+

# To give you:

output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
               Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
               Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10),
               Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
               Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
               Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)
              ]
final_df = spark.createDataFrame(output_rows)
final_df.show()

+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
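To check either Spark answer against the expected result, it helps to compute that result without Spark. The following plain-Python sketch is not from the question or the answers: it splits the question's input at every start and end time, drops gaps that no row covers, and keeps the rate of the most recently updated covering row. It assumes half-open intervals and ISO-formatted timestamp strings, which compare chronologically; the function name `split_and_resolve` is hypothetical.

```python
from typing import List, Tuple

# (start_time, end_time, rate, updated_at), copied from the question's input.
rows = [
    ("2018-01-01 00:00:00", "2018-01-04 00:00:00", 10, "2021-02-25 00:00:00"),
    ("2018-01-02 00:00:00", "2018-01-03 00:00:00", 10, "2021-02-25 00:00:00"),
    ("2018-01-03 00:00:00", "2018-01-05 00:00:00", 20, "2021-02-20 00:00:00"),
    ("2018-01-06 00:00:00", "2018-01-07 00:00:00", 30, "2021-02-25 00:00:00"),
    ("2018-01-07 00:00:00", "2018-01-08 00:00:00", 30, "2021-02-25 00:00:00"),
]


def split_and_resolve(rows) -> List[Tuple[str, str, int]]:
    # 1. Every start or end time is a potential sub-interval boundary.
    points = sorted({t for start, end, _, _ in rows for t in (start, end)})
    result = []
    # 2. Consecutive boundaries form elementary, non-overlapping sub-intervals.
    for lo, hi in zip(points, points[1:]):
        # 3. Rows covering [lo, hi); half-open, so touching intervals do not overlap.
        covering = [(updated_at, rate) for start, end, rate, updated_at in rows
                    if start <= lo and hi <= end]
        if not covering:
            continue  # e.g. the hole 2018-01-05..2018-01-06
        # 4. Latest updated_at wins (a tie on updated_at goes to the larger rate).
        _, rate = max(covering)
        result.append((lo, hi, rate))
    return result


for row in split_and_resolve(rows):
    print(row)
```

The printed rows match the final table above, including rate 10 for 2018-01-03..2018-01-04, because the row carrying rate 10 was updated on 2021-02-25, after the rate-20 row.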
Answer 1 by 093gszye

This works:

from pyspark.sql import functions as F, Row, SparkSession, Window
from pyspark.sql.types import BooleanType

spark = (SparkSession.builder
    .master("local")
    .appName("Octopus")
    .config('spark.sql.autoBroadcastJoinThreshold', -1)
    .getOrCreate())

input_rows = [Row(idx=0, interval_start='2018-01-01 00:00:00', interval_end='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
              Row(idx=0, interval_start='2018-01-02 00:00:00', interval_end='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: full overlap for (2,3) with (1,4)
              Row(idx=0, interval_start='2018-01-03 00:00:00', interval_end='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'),  # OVERLAP: (3,5) and (1,4) and rate=10/20
              Row(idx=0, interval_start='2018-01-06 00:00:00', interval_end='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'),  # NO OVERLAP: hole between (5,6)
              Row(idx=0, interval_start='2018-01-07 00:00:00', interval_end='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')]  # NO OVERLAP

df = spark.createDataFrame(input_rows)
df.show()

# Compute overlapping intervals with a self-join on the constant key idx

def overlap(start_first, end_first, start_second, end_second):
    # True if an endpoint of one interval lies strictly inside the other.
    # The timestamps are ISO-formatted strings, so string comparison is chronological.
    return ((start_first < start_second < end_first) or (start_first < end_second < end_first)
           or (start_second < start_first < end_second) or (start_second < end_first < end_second))
spark.udf.register('overlap', overlap, BooleanType())

df.createOrReplaceTempView("df1")
df.createOrReplaceTempView("df2")
df = df.cache()

overlap_df = spark.sql("""
     SELECT df1.idx, df1.interval_start, df1.interval_end, df1.rate AS rate FROM df1 JOIN df2
     ON df1.idx == df2.idx
     WHERE overlap(df1.interval_start, df1.interval_end, df2.interval_start, df2.interval_end)
""")
overlap_df = overlap_df.cache()

# Compute NON overlapping intervals

non_overlap_df = df.join(overlap_df, ['interval_start', 'interval_end'], 'leftanti')

# Stack overlapping points

interval_point = overlap_df.select('interval_start').union(overlap_df.select('interval_end'))
interval_point = interval_point.withColumnRenamed('interval_start', 'p').distinct().sort('p')

# Construct continuous overlapping intervals: pair each point with the next one.
# The window needs an explicit orderBy; without it the row order is not guaranteed.

w = Window.orderBy('p')

interval_point = interval_point.withColumn('interval_end', F.lead('p').over(w)).dropna(subset=['p', 'interval_end'])
interval_point = interval_point.withColumnRenamed('p', 'interval_start')

# Stack continuous overlapping intervals and non overlapping intervals

df3 = interval_point.select('interval_start', 'interval_end').union(non_overlap_df.select('interval_start', 'interval_end'))

# Point in interval range join

# https://docs.databricks.com/delta/join-performance/range-join.html

df3.createOrReplaceTempView("df3")
df.createOrReplaceTempView("df")
sql = """SELECT df3.interval_start, df3.interval_end, df.rate, df.updated_at
         FROM df3 JOIN df ON df3.interval_start BETWEEN df.interval_start and df.interval_end - INTERVAL 1 seconds"""
df4 = spark.sql(sql)
df4.sort('interval_start').show()

# select non overlapped intervals and keep most up to date rate value for overlapping intervals

(df4.groupBy('interval_start', 'interval_end')
    .agg(F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate'))
    .orderBy("interval_start")).show()

+-------------------+-------------------+----+
|     interval_start|       interval_end|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
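The first steps of this answer (overlap self-join, anti-join for the rest, then pairing each boundary with the next one) can be mirrored in plain Python to see why the touching intervals 2018-01-06..07 and 2018-01-07..08 are left whole. This sketch reuses the answer's `overlap` predicate but, as a simplifying assumption, replaces the timestamps with January 2018 day numbers. With strict inequalities, two rows with exactly the same start and end are not flagged by this predicate.

```python
def overlap(start_first, end_first, start_second, end_second):
    # Same strict-inequality test as the answer's UDF: an endpoint of one
    # interval lies strictly inside the other.
    return ((start_first < start_second < end_first) or (start_first < end_second < end_first)
            or (start_second < start_first < end_second) or (start_second < end_first < end_second))


intervals = [(1, 4), (2, 3), (3, 5), (6, 7), (7, 8)]  # day numbers, January 2018

# Self-join: keep every interval that overlaps at least one other interval.
overlapping = sorted({a for a in intervals for b in intervals if overlap(*a, *b)})
# Left anti join on (start, end): everything else stays as-is.
non_overlapping = [iv for iv in intervals if iv not in overlapping]
# Distinct sorted boundaries of the overlapping intervals, each paired with the next one.
points = sorted({t for iv in overlapping for t in iv})
elementary = list(zip(points, points[1:]))

print(overlapping)      # [(1, 4), (2, 3), (3, 5)]
print(non_overlapping)  # [(6, 7), (7, 8)]
print(elementary)       # [(1, 2), (2, 3), (3, 4), (4, 5)]
```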
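Answer 2 by h9a6wy2h

You can explode a sequence of timestamps, as in your intermediate DataFrame, and then group by start time and end time to keep the most recent rate according to updated_at. This assumes all interval boundaries fall on the same one-day grid (here, midnight), because both sequences step by one day.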
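The `sequence`/`arrays_zip`/`inline` expression turns each row into consecutive one-day slices. A plain-Python analogue, using a hypothetical helper `daily_slices` and only intended for intervals at least one day long, makes that slicing easy to check; the second call shows that a trailing partial day yields no slice.

```python
from datetime import datetime, timedelta
from typing import List, Tuple

FMT = "%Y-%m-%d %H:%M:%S"


def daily_slices(start: str, end: str) -> List[Tuple[str, str]]:
    """Pair each day in [start, end) with the following day, mirroring
    inline(arrays_zip(sequence(start, end - 1 day, 1 day),
                      sequence(start + 1 day, end, 1 day)))."""
    t, stop = datetime.strptime(start, FMT), datetime.strptime(end, FMT)
    step = timedelta(days=1)
    slices = []
    while t + step <= stop:
        slices.append((t.strftime(FMT), (t + step).strftime(FMT)))
        t += step
    return slices


print(daily_slices("2018-01-01 00:00:00", "2018-01-04 00:00:00"))
# A trailing 12-hour remainder produces no slice.
print(daily_slices("2018-01-01 00:00:00", "2018-01-03 12:00:00"))
```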

import pyspark.sql.functions as F

output = df.selectExpr(
    """
    inline(arrays_zip(
        sequence(timestamp(start_time), timestamp(end_time) - interval 1 day, interval 1 day),
        sequence(timestamp(start_time) + interval 1 day, timestamp(end_time), interval 1 day)
    )) as (start_time, end_time)
    """,
    "rate", "updated_at"
).groupBy(
    'start_time', 'end_time'
).agg(
    F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate')
).orderBy("start_time")

output.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
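Both answers resolve conflicts with `F.max(F.struct('updated_at', 'rate'))['rate']`. Spark orders structs field by field, so the maximum struct is the one with the latest updated_at, and a tie on updated_at goes to the larger rate. Python tuples compare the same way, so the grouping step can be sketched in plain Python; the `exploded` rows below are a hand-written subset of the question's data (dates shortened), not Spark output.

```python
from collections import defaultdict

# (start_time, end_time, rate, updated_at) after the day-by-day explode,
# restricted to two sub-intervals of the question's example.
exploded = [
    ("2018-01-02", "2018-01-03", 10, "2021-02-25 00:00:00"),
    ("2018-01-02", "2018-01-03", 10, "2021-02-25 00:00:00"),
    ("2018-01-03", "2018-01-04", 10, "2021-02-25 00:00:00"),
    ("2018-01-03", "2018-01-04", 20, "2021-02-20 00:00:00"),
]

# groupBy('start_time', 'end_time'), collecting (updated_at, rate) like struct('updated_at', 'rate').
groups = defaultdict(list)
for start, end, rate, updated_at in exploded:
    groups[(start, end)].append((updated_at, rate))

# Tuple max compares updated_at first (ISO strings sort chronologically), then rate;
# [1] picks the rate field, like ['rate'] on the struct.
latest_rate = {key: max(structs)[1] for key, structs in groups.items()}
print(latest_rate)
```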
