将自定义函数应用于具有多个列的PySpark Dataframe 的最优雅的方法是什么?

nbysray5  于 2023-06-21  发布在  Spark
关注(0)|答案(1)|浏览(93)

我需要创建基于三个dataframe字段的新字段。这是有效的,但似乎效率低下:

def my_func(very_long_field_name_a, very_long_field_name_b, very_long_field_name_c):
  if very_long_field_name_a >= very_long_field_name_b and very_long_field_name_c <= very_long_field_name_b:
    return 'Y'
  elif very_long_field_name_a <= very_long_field_name_b and very_long_field_name_c >= very_long_field_name_b:
    return 'Y'
  else: 
    return 'N'

import pyspark.sql.functions as F
my_udf = F.udf(my_func)

df.withColumn('new_field', my_udf(df.very_long_field_name_a, df.very_long_field_name_b, df.very_long_field_name_c)).display()

有没有可能像这样传递 Dataframe ?我试了一下,但出现了一个错误:

def my_func(df):
  if df.very_long_field_name_a >= df.very_long_field_name_b and df.very_long_field_name_c <= df.very_long_field_name_b:
    return 'Y'
  df.elif very_long_field_name_a <= df.very_long_field_name_b and df.very_long_field_name_c >= df.very_long_field_name_b:
    return 'Y'
  else: 
    return 'N'

import pyspark.sql.functions as F
my_udf = F.udf(my_func)
df.withColumn('new_field', my_udf(df)).display()

Invalid argument, not a string or column:

我想缩短它的原因是因为我已经创建了六个新字段。复制和粘贴所有作为参数传递的字段名似乎效率很低,所以我想知道是否有更干净的方法。

a0x5cqrl

a0x5cqrl1#

要基于DataFrame中的多个列创建新字段,而不显式地将每个列作为参数传递给UDF,可以使用PySpark中的struct函数。struct函数将多个列合并为一个StructType列。下面是一个例子:

import pyspark.sql.functions as F

def my_func(row):
    if row.very_long_field_name_a >= row.very_long_field_name_b and row.very_long_field_name_c <= row.very_long_field_name_b:
        return 'Y'
    elif row.very_long_field_name_a <= row.very_long_field_name_b and row.very_long_field_name_c >= row.very_long_field_name_b:
        return 'Y'
    else:
        return 'N'

my_udf = F.udf(my_func)

# Use struct to combine the necessary columns into a single column
df = df.withColumn('combined_fields', F.struct('very_long_field_name_a', 'very_long_field_name_b', 'very_long_field_name_c'))

# Apply the UDF to the combined column
df = df.withColumn('new_field', my_udf(F.col('combined_fields')))

# Drop the temporary combined column
df = df.drop('combined_fields')

df.display()

在这种方法中,我们使用struct函数将必要的列(very_long_field_name_avery_long_field_name_bvery_long_field_name_c)组合成一个名为combined_fields的列。然后,我们使用my_udf(F.col('combined_fields'))将UDF应用于combined_fields列。最后,我们使用df.drop('combined_fields')删除临时组合列。
通过使用struct,可以避免将每一列作为参数显式传递给UDF,从而使代码更简洁、更高效。

相关问题