How to efficiently filter a PySpark DataFrame on conditions listed in a dictionary?

7hiiyaii · posted 2023-05-06 · in Spark
Follow (0) | Answers (2) | Views (119)

How can I efficiently filter a PySpark DataFrame on conditions listed in a dictionary, without using any for loops? For example, I have a DataFrame (df) as shown below (screenshot omitted). I have a dictionary of conditions whose keys are product_pt_family values and whose values are acceptance_rate thresholds, i.e.,

dict1 = {'Fruits & Vegetables': 85, 'Dairy & Eggs': 90, 'Water': 91, 'Bakery': 92}

I want to filter the DataFrame on these conditions, i.e. keep only the rows whose acceptance_rate meets or exceeds the dictionary value for that row's product_pt_family.
That is, if product_pt_family is "Fruits & Vegetables", keep only rows with acceptance_rate >= 85; if product_pt_family is "Water", keep only rows with acceptance_rate >= 91; and so on.
The final DataFrame after filtering with this dictionary is then as follows (screenshot omitted).

The code I tried, using the dictionary to produce this final DataFrame, is:

pt_families = list(set(df.select('product_pt_family').toPandas()['product_pt_family']))

schema = df.schema
thresholded_df = spark.createDataFrame([], schema)

for pt_family in pt_families:
    df_1 = df.filter((df.product_pt_family == pt_family) & (df.acceptance_rate >= dict1[pt_family]))
    thresholded_df = thresholded_df.union(df_1)

thresholded_df.show()


This takes a lot of time. Is there a fast, efficient way to filter the DataFrame without the for loop and the unions?
Thanks.


kxkpmulp1#

There is no need for a UDF here. UDFs should be avoided unless unavoidable, because the Catalyst Optimizer cannot see their implementation and therefore cannot optimize the plan around them; Spark recommends the built-in functions instead.
This will work:

from pyspark.sql import functions as F

dict1 = {'Fruits & Vegetables': 85, 'Dairy & Eggs': 90}

conditions = (F.lit(1) == F.lit(0))  # always-false seed, so OR-ing real conditions onto it has no effect
for family, rate in dict1.items():
    conditions = conditions | ((F.col("product_pt_family") == family) & (F.col("acceptance_rate") >= rate))

df \
    .filter(conditions) \
    .show(truncate=False)

Input:

+-------+-------------------+---------------+
|item_id|product_pt_family  |acceptance_rate|
+-------+-------------------+---------------+
|1      |Fruits & Vegetables|88.96          |
|2      |Fruits & Vegetables|80.96          |
|3      |Dairy & Eggs       |80.0           |
|4      |Fruits & Vegetables|91.0           |
|5      |Dairy & Eggs       |95.0           |
+-------+-------------------+---------------+

Output:

+-------+-------------------+---------------+
|item_id|product_pt_family  |acceptance_rate|
+-------+-------------------+---------------+
|1      |Fruits & Vegetables|88.96          |
|4      |Fruits & Vegetables|91.0           |
|5      |Dairy & Eggs       |95.0           |
+-------+-------------------+---------------+

dz6r00yl2#

I have updated my answer; it uses neither a Spark UDF nor any loop statement.
Let's assume this is your DataFrame (df):

+--------+-------------------+---------------+
| item_id|  product_pt_family|acceptance_rate|
+--------+-------------------+---------------+
|51259338|Fruits & Vegetables|          81.45|
|22660282|Fruits & Vegetables|           98.5|
|10450119|       Dairy & Eggs|          89.65|
|10450118|       Dairy & Eggs|          90.32|
+--------+-------------------+---------------+

1. Convert the dictionary to a DataFrame

threshold_dict = {'Fruits & Vegetables': 85, 'Dairy & Eggs': 90}

threshold_df = spark.createDataFrame(threshold_dict.items(), schema="product_pt_family STRING, threshold_rate INT")

2. Use a join to get the desired output

joined_df = df.join(threshold_df, "product_pt_family")

out_df = joined_df.filter("acceptance_rate >= threshold_rate").drop("threshold_rate")

out_df.show()

Output:

+-------------------+--------+---------------+
|  product_pt_family| item_id|acceptance_rate|
+-------------------+--------+---------------+
|       Dairy & Eggs|10450118|          90.32|
|Fruits & Vegetables|22660282|           98.5|
+-------------------+--------+---------------+
