scala Spark在满足条件列中获得最小值

pexxcrt2  于 2023-10-18  发布在  Scala
关注(0)|答案(3)|浏览(139)

我在spark中有一个DataFrame,看起来像这样:

id |  flag
----------
 0 |  true
 1 |  true
 2 | false
 3 |  true
 4 |  true
 5 |  true
 6 | false
 7 | false
 8 |  true
 9 | false

我想获取另一个列,如果它有flag == false,或者下一个false值的rowNumber,那么它的当前rowNumber,所以输出会像这样:

id |  flag | nextOrCurrentFalse
-------------------------------
 0 |  true |                  2
 1 |  true |                  2
 2 | false |                  2
 3 |  true |                  6
 4 |  true |                  6
 5 |  true |                  6
 6 | false |                  6
 7 | false |                  7
 8 |  true |                  9
 9 | false |                  9

我想以向量化的方式来做这件事(而不是按行迭代)。所以我实际上想要的逻辑是:

  • 对于每一行,获取大于或等于当前行的最小id,该id具有标志== false
y53ybaqx

y53ybaqx1#

在考虑了扩展等问题之后--但不清楚Catalyst是否足够好--我提出了一个解决方案,该解决方案基于一个可以从分区中受益的答案,并且只需考虑数据即可。它是关于预计算和处理的,一些按摩可以击败蛮力方法。你关于JOIN的观点不是一个问题,因为这是一个有界的JOIN,没有大量的数据生成。
你对框架方法的评论有点厌倦,因为所有超过这里的都是框架。我想你的意思是你想循环通过一个 Dataframe ,并有一个子循环与出口。我找不到这样的例子,事实上,我不确定它是否符合SPARK范式。同样的结果,更少的处理:

import org.apache.spark.sql.functions._
import spark.implicits._
import org.apache.spark.sql.expressions.Window

val df = Seq((0, true), (1, true), (2,false), (3, true), (4,true), (5,true), (6,false), (7,false), (8,true), (9,false)).toDF("id","flag")
@transient val  w1 = org.apache.spark.sql.expressions.Window.orderBy("id1")  

val ids = df.where("flag = false") 
            .select($"id".as("id1"))  

val ids2 = ids.select($"*", lag("id1",1,-1).over(w1).alias("prev_id"))
val ids3 = ids2.withColumn("prev_id1", col("prev_id")+1).drop("prev_id")

// Less and better performance at scale, this is better theoretically for Catalyst to bound partitions? Less work to do in any event.
// Some understanding of data required! And no grouping and min.
val withNextFalse = df.join(ids3, df("id") >= ids3("prev_id1") && df("id") <= ids3("id1"))
                     .select($"id", $"flag", $"id1".alias("nextOrCurrentFalse"))
                     .orderBy(asc("id"),asc("id"))

withNextFalse.show(false)

还返回:

+---+-----+------------------+
|id |flag |nextOrCurrentFalse|
+---+-----+------------------+
|0  |true |2                 |
|1  |true |2                 |
|2  |false|2                 |
|3  |true |6                 |
|4  |true |6                 |
|5  |true |6                 |
|6  |false|6                 |
|7  |false|7                 |
|8  |true |9                 |
|9  |false|9                 |
+---+-----+------------------+
x6492ojm

x6492ojm2#

如果flag相当稀疏,你可以这样做:

val ids = df.where("flag = false"). 
             select($"id".as("id1"))  

val withNextFalse = df.join(ids, df("id") <= ids("id1")).
                      groupBy("id", "flag").
                      agg("id1" -> "min")

在第一步中,我们为标志为false的id创建一个框架。然后,我们在所需的条件下将该嵌套框架连接到原始数据(原始id应小于或等于flag为false的行的id)。
要获得 first 这种情况,请按id分组并使用agg查找id1的最小值(这是flag = false的行的id)。
在示例数据上运行(并按id排序)会得到所需的输出:

+---+-----+--------+
| id| flag|min(id1)|
+---+-----+--------+
|  0| true|       2|
|  1| true|       2|
|  2|false|       2|
|  3| true|       6|
|  4| true|       6|
|  5| true|       6|
|  6|false|       6|
|  7|false|       7|
|  8| true|       9|
|  9|false|       9|
+---+-----+--------+

如果DataFrame非常大,并且有许多行的标志为False,则这种方法可能会遇到性能问题。如果是这种情况,您可能会更好地使用迭代解决方案。

ukqbszuj

ukqbszuj3#

看看其他的答案哪个更好,但是把这个留在这里是为了SQL教育的目的-可能的。
这是你想要的,但我很想知道其他人对这一点的看法。我将检查催化剂,看看它是如何工作的程序,但我认为这可能意味着一些错过分区边界,我热衷于检查,以及。

import org.apache.spark.sql.functions._
val df = Seq((0, true), (1, true), (2,false), (3, true), (4,true), (5,true), (6,false), (7,false), (8,true), (9,false)).toDF("id","flag")
df.createOrReplaceTempView("tf") 

// Performance? Need to check at some stage how partitioning works in such a case.
spark.sql("CACHE TABLE tf") 
val res1 = spark.sql("""  
                       SELECT tf1.*, tf2.id as id2, tf2.flag as flag2
                         FROM tf tf1, tf tf2
                        WHERE tf2.id  >= tf1.id
                          AND tf2.flag = false 
                     """)    

//res1.show(false)
res1.createOrReplaceTempView("res1") 
spark.sql("CACHE TABLE res1") 

val res2 = spark.sql(""" SELECT X.id, X.flag, X.id2 
                           FROM (SELECT *, RANK() OVER (PARTITION BY id ORDER BY id2 ASC) as rank_val 
                                   FROM res1) X
                          WHERE X.rank_val = 1
                       ORDER BY id
                    """) 

res2.show(false)

相关问题