python — Retrieve top n from each group of a DataFrame in PySpark

r8xiu3jd asked on 2021-07-09 in Spark

I have a DataFrame in PySpark with data like the following:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

What I expect is to return the 2 records with the highest score in each group of records sharing the same user_id. So the result should look like this:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

I'm really new to pyspark; could anyone give me a code snippet or a pointer to the relevant documentation for this problem? Many thanks!


8fsztsew1#

Use the ROW_NUMBER() function in a PySpark SQL query:

SELECT * FROM (
    SELECT e.*, 
    ROW_NUMBER() OVER (ORDER BY col_name DESC) rn 
    FROM Employee e
)
WHERE rn = N

where N selects the desired nth largest value of the column.
Output:

+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

The query will return the nth largest value.
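
To run this from PySpark and match the per-group requirement in the question, a PARTITION BY clause can be added; below is a minimal sketch, assuming the question's data is in a DataFrame df and spark is an active SparkSession:

# Sketch: per-group top 2 via Spark SQL (adds PARTITION BY to the answer's query).
df.createOrReplaceTempView("scores")
spark.sql("""
    SELECT user_id, object_id, score FROM (
        SELECT s.*,
               ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY score DESC) AS rn
        FROM scores s
    ) t
    WHERE t.rn <= 2
""").show()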


tkclm6bt2#

Use row_number instead of rank if you want exactly n rows per group when there are ties in rank:

from pyspark.sql import Window
from pyspark.sql.functions import col, row_number

n = 5
window = Window.partitionBy('user_id').orderBy(col('score').desc())  # per-user window, as in the question
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

Note: the limit(20).toPandas() trick is used instead of show() for nicer formatting in Jupyter notebooks.
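
To see the difference on ties, here is a small sketch with hypothetical data containing a tied score; rank() keeps both tied rows within the cutoff, while row_number() keeps exactly n:

from pyspark.sql import Window
from pyspark.sql.functions import col, rank, row_number

# Hypothetical data: user_1 has two objects tied at score 3.
df_tie = spark.createDataFrame(
    [("user_1", "object_1", 3), ("user_1", "object_2", 3), ("user_1", "object_3", 1)],
    ["user_id", "object_id", "score"])

w = Window.partitionBy("user_id").orderBy(col("score").desc())
# rank() assigns 1 to both tied rows (so rank <= 1 returns two rows);
# row_number() breaks the tie arbitrarily (so row_number <= 1 returns one row).
df_tie.select("*", rank().over(w).alias("rank"),
              row_number().over(w).alias("row_number")).show()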


yduiuuwa3#

Using Python 3 and Spark 2.4:

from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(f.col(order_by_column).desc())).alias('row_rank')])
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

top_n_df = get_topN(your_dataframe, group_by_columns, order_by_column, n=1)
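
Applied to the question's sample data, a usage sketch might look like this (assuming an active SparkSession named spark):

df = spark.createDataFrame(
    [("user_1", "object_1", 3), ("user_1", "object_2", 2),
     ("user_2", "object_1", 5), ("user_2", "object_2", 2),
     ("user_2", "object_2", 6)],
    ["user_id", "object_id", "score"])

# Top 2 scores per user_id
get_topN(df, ["user_id"], "score", n=2).show()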

vc6uscn94#

I know the question was asked for pyspark, but I was looking for a similar answer in Scala (see: Retrieve top n values in each group of a DataFrame in Scala).
Here is the Scala version of @mtoto's answer.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col

val window = Window.partitionBy("user_id").orderBy(col("score").desc)
val rankByScore = rank().over(window)
df1.select(col("*"), rankByScore.alias("rank")).filter(col("rank") <= 2).show()

// You can change the value 2 to any number you want; here 2 selects the top 2 values.
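
For completeness, df1 can be built from the question's sample data; a minimal sketch, assuming an active SparkSession named spark:

import spark.implicits._

// Sample data from the question
val df1 = Seq(
  ("user_1", "object_1", 3), ("user_1", "object_2", 2),
  ("user_2", "object_1", 5), ("user_2", "object_2", 2),
  ("user_2", "object_2", 6)
).toDF("user_id", "object_id", "score")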

More examples can be found here.


nhaq1z215#

Here is another solution, without window functions, to get the top n records from a PySpark DataFrame.


# Import libraries
from pyspark.sql.functions import col

# Sample data
rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

# Get the top n records as Row objects
row_list = df.orderBy(col("score").desc()).head(5)

# Convert the Row objects back to a DataFrame
sorted_df = spark.createDataFrame(row_list)

# Display the DataFrame
sorted_df.show()

Output:

+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_2| object_2|    6|
| user_2| object_1|    5|
| user_1| object_1|    3|
| user_1| object_2|    2|
| user_2| object_2|    2|
+-------+---------+-----+
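
Note that head(5) takes the top n rows overall, not per group. If a per-group top n is needed while still avoiding window functions, one option (a sketch added here, not part of the original answer; requires Spark 2.4+ for slice) is to collect each group into a sorted array and slice it:

from pyspark.sql.functions import col, struct, collect_list, sort_array, slice, explode

# Per-group top 2 without window functions: sort each user's (score, object_id)
# structs in descending order, keep the first two, then flatten back to rows.
top_2 = (df.groupBy("user_id")
           .agg(sort_array(collect_list(struct("score", "object_id")), asc=False).alias("rows"))
           .select("user_id", explode(slice(col("rows"), 1, 2)).alias("r"))
           .select("user_id", col("r.object_id").alias("object_id"), col("r.score").alias("score")))
top_2.show()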

If you are interested in more window functions in Spark, you can refer to my blog: https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86


14ifxucb6#

I believe you need to use window functions to attain the rank of each row based on user_id and score, and then filter your results to keep only the top two values.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()

# +-------+---------+-----+----+
# |user_id|object_id|score|rank|
# +-------+---------+-----+----+
# | user_1| object_1|    3|   1|
# | user_1| object_2|    2|   2|
# | user_2| object_2|    6|   1|
# | user_2| object_1|    5|   2|
# +-------+---------+-----+----+

In general, the official programming guide is a good place to start learning Spark.

Data:

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
