scala—如何在sparkDataframe中查找更改发生点

drnojrws  于 2021-07-14  发布在  Spark
关注(0)|答案(1)|浏览(269)

我以一个简单的Dataframe为例:

val someDF = Seq(
  (1, "A"),
  (2, "A"),
  (3, "A"),
  (4, "B"),
  (5, "B"),
  (6, "A"),
  (7, "A"),
  (8, "A")
).toDF("t", "state")

// this part is half pseudocode
someDF.aggregate((acc, cur) => {
    if (acc.last.state != cur.state) {
        acc.add(cur)
    }
}, List()).show(truncate=false)

“t”列表示时间点,“state”列表示该时间点的状态。
我希望找到的是每一个变化发生的第一时间加上第一行,如:

(1, "A")
(4, "B")
(6, "A")

我也看过sql中的解决方案,但是它们涉及复杂的自连接和窗口函数,我不完全理解,但是sql解决方案也可以。
spark中有很多函数(fold、aggregate、reduce…)我觉得它们可以做到这一点,但是我不能理解其中的区别,因为我对spark的概念(比如分区)还比较陌生,如果分区会影响结果,那就有点棘手了。

lnlaulya

lnlaulya1#

你可以使用窗口功能 lag 用于与上一行比较,以及 row_number 检查是否为第一行:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val result = someDF.withColumn(
    "change", 
    lag("state", 1).over(Window.orderBy("t")) =!= col("state") || 
    row_number().over(Window.orderBy("t")) === 1
).filter("change").drop("change")

result.show
+---+-----+
|  t|state|
+---+-----+
|  1|    A|
|  4|    B|
|  6|    A|
+---+-----+

对于sql解决方案:

someDF.createOrReplaceTempView("mytable")
val result = spark.sql("""
    select t, state 
    from (
        select 
            t, state, 
            lag(state) over (order by t) != state or row_number() over (order by t) = 1 as change 
       from mytable
    ) 
    where change
""")

相关问题