在大数据上优化pyspark udf

70gysomp  于 2021-07-13  发布在  Spark
关注(0)|答案(1)|浏览(330)

我试图优化这个代码,当列的值(pysparkDataframe的)在[categories]中时创建一个伪值。
当运行在10万行上时,运行大约需要30秒。在我的情况下,我有大约2000万行,这将需要很多时间。

def create_dummy(dframe,col_name,top_name,categories,**options):
    lst_tmp_col = []
    if 'lst_tmp_col' in options:
        lst_tmp_col = options["lst_tmp_col"]
    udf = UserDefinedFunction(lambda x: 1 if x in categories else 0, IntegerType())
    dframe  = dframe.withColumn(str(top_name), udf(col(col_name))).cache()
    dframe = dframe.select(lst_tmp_col+ [str(top_name)])
    return dframe

换言之,如何优化此函数并减少与数据量相关的总时间?如何确保这个函数不会遍历我的数据?
谢谢你的建议。谢谢

qij5mzcb

qij5mzcb1#

编码类别不需要自定义项。你可以用 isin :

import pyspark.sql.functions as F

def create_dummy(dframe, col_name, top_name, categories,**options):
    lst_tmp_col = []
    if 'lst_tmp_col' in options:
        lst_tmp_col = options["lst_tmp_col"]
    dframe = dframe.withColumn(str(top_name), F.col(col_name).isin(categories).cast("int")).cache()
    dframe = dframe.select(lst_tmp_col + [str(top_name)])
    return dframe

相关问题