包含stringtype元素的arraytype列上的udf函数

hpxqektj  于 2021-07-13  发布在  Spark
关注(0)|答案(1)|浏览(428)

我需要一个udf函数来输入dataframe的数组列,并对其中的两个字符串元素执行相等性检查。我的Dataframe有这样一个模式。
IDDATEOPTION12021-01-06[“红色”,“绿色”]22021-01-07[“蓝色”,“蓝色”]32021-01-08[“蓝色”,“黄色”]42021-01-09nan
我试过这个:

def equality_check(options: list):
  try:
   if options[0] == options[1]:
     return 1
   else:
     return 0
  except:
     return -1
equality_udf = f.udf(equality_check, t.IntegerType())

但它抛出了索引错误。我确信options列是字符串数组。期望是:
iddateoptionsequality\u check12021-01-06[‘红色’,‘绿色’]022021-01-07[‘蓝色’,‘蓝色’]132021-01-08[‘蓝色’,‘黄色’]042021-01-09nan-1

pcrecxhr

pcrecxhr1#

你可以检查一下 options 列表已定义或其长度小于2,而不是使用try/except。下面是一个工作示例:

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

data = [
    (1, "2021-01-06", ['red', 'green']),
    (2, "2021-01-07", ['Blue', 'Blue']),
    (3, "2021-01-08", ['Blue', 'Yellow']),
    (4, "2021-01-09", None),
]
df = spark.createDataFrame(data, ["ID", "date", "options"])

def equality_check(options: list):
    if not options or len(options) < 2:
        return -1

    return int(options[0] == options[1])

equality_udf = F.udf(equality_check, IntegerType())

df1 = df.withColumn("equality_check", equality_udf(F.col("options")))
df1.show()

# +---+----------+--------------+--------------+

# | ID|      date|       options|equality_check|

# +---+----------+--------------+--------------+

# |  1|2021-01-06|  [red, green]|             0|

# |  2|2021-01-07|  [Blue, Blue]|             1|

# |  3|2021-01-08|[Blue, Yellow]|             0|

# |  4|2021-01-09|          null|            -1|

# +---+----------+--------------+--------------+

但是,我建议您不要使用自定义项,因为您可以仅使用内置函数执行相同的操作:

df1 = df.withColumn(
    "equality_check",
    F.when(F.size(F.col("options")) < 2, -1)
        .when(F.col("options")[0] == F.col("options")[1], 1)
        .otherwise(0)
)

相关问题