Unable to create an array literal in Spark/PySpark

daupos2t · posted 2023-05-06 in Spark

I'm having trouble removing rows from a DataFrame based on a two-column list of items. For example, given this DataFrame:

df = spark.createDataFrame([(100, 'A', 304), (200, 'B', 305), (300, 'C', 306)], ['number', 'letter', 'id'])
df.show()
# +------+------+---+
# |number|letter| id|
# +------+------+---+
# |   100|     A|304|
# |   200|     B|305|
# |   300|     C|306|
# +------+------+---+

I can easily remove rows using isin on a single column:

df.where(~col('number').isin([100, 200])).show()
# +------+------+---+
# |number|letter| id|
# +------+------+---+
# |   300|     C|306|
# +------+------+---+

But when I try to remove them based on two columns, I get an exception:

df.where(~array('number', 'letter').isin([(100, 'A'), (200, 'B')])).show()
Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.lit.
: java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [100, A]
    at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:57)
    at org.apache.spark.sql.functions$.lit(functions.scala:101)
    at org.apache.spark.sql.functions.lit(functions.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:280)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:745)

After some investigation, I realized that the root cause of the problem is creating literals from non-primitive types. I tried the following code in PySpark:

lit((100, 'A'))
lit([100, 'A'])

And in Scala:

lit((100, "A"))
lit(List(100, "A"))
lit(Seq(100, "A"))
lit(Array(100, "A"))

But without success. Does anyone know how to create an array literal in Spark/PySpark? Or is there another way to filter a DataFrame on two columns?


nhaq1z211#

First of all, you probably want struct rather than array. Keep in mind that Spark SQL doesn't support heterogeneous arrays, so array(1, 'a') is cast to array<string>.
The query could therefore look like this:

choices = [(100, 'A'), (200, 'B')]

target = [
    struct(
        lit(number).alias("number").cast("long"), 
        lit(letter).alias("letter").cast("string")) 
    for number, letter  in choices]

query = struct("number", "letter").isin(target)

This seems to generate a valid expression:

query
Column<b'(named_struct(NamePlaceholder(), number, NamePlaceholder(), letter) IN (named_struct(col1, CAST(100 AS `number` AS BIGINT), col2, CAST(A AS `letter` AS STRING)), named_struct(col1, CAST(200 AS `number` AS BIGINT), col2, CAST(B AS `letter` AS STRING))))'>

But for some reason it doesn't pass the analyzer:

df.where(~query)
AnalysisException                         Traceback (most recent call last)
...
AnalysisException: "cannot resolve '(named_struct('number', `number`, 'letter', `letter`) IN (named_struct('col1', CAST(100 AS BIGINT), 'col2', CAST('A' AS STRING)), named_struct('col1', CAST(200 AS BIGINT), 'col2', CAST('B' AS STRING))))' due to data type mismatch: Arguments must be same type;;\n'Filter NOT named_struct(number, number#0L, letter, letter#1) IN (named_struct(col1, cast(100 as bigint), col2, cast(A as string)),named_struct(col1, cast(200 as bigint), col2, cast(B as string)))\n+- LogicalRDD [number#0L, letter#1, id#2L]\n"

Oddly enough, the following SQL also fails:

df.createOrReplaceTempView("df")

spark.sql("SELECT * FROM df WHERE struct(letter, letter) IN (struct(CAST(1 AS bigint), 'a'))")
AnalysisException: "cannot resolve '(named_struct('letter', df.`letter`, 'letter', df.`letter`) IN (named_struct('col1', CAST(1 AS BIGINT), 'col2', 'a')))' due to data type mismatch: Arguments must be same type; line 1 pos 46;\n'Project [*]\n+- 'Filter named_struct(letter, letter#1, letter, letter#1) IN (named_struct(col1, cast(1 as bigint), col2, a))\n   +- SubqueryAlias df\n      +- LogicalRDD [number#0L, letter#1, id#2L]\n"

But when both sides are replaced with literals:

spark.sql("SELECT * FROM df WHERE struct(CAST(1 AS bigint), 'a') IN (struct(CAST(1 AS bigint), 'a'))")
DataFrame[number: bigint, letter: string, id: bigint]

it works fine. So it looks like a bug.
That said, a left anti join should work just fine here, for example:

df.join(
    spark.createDataFrame(choices, ['number', 'letter']),
    ['number', 'letter'],
    'left_anti'
).show()
+------+------+---+
|number|letter| id|
+------+------+---+
|   300|     C|306|
+------+------+---+

hiz5n14c2#

To create an array literal in Spark, you build an array from a series of columns, each of which is created with the lit function:

scala> array(lit(100), lit("A"))
res1: org.apache.spark.sql.Column = array(100, A)

kzipqqlq3#

Spark 3.4+

F.lit([5, 7])

Full example:

from pyspark.sql import functions as F

df = spark.range(2)
df = df.withColumn("c1", F.lit([5, 7]))

df.show()
# +---+------+
# | id|    c1|
# +---+------+
# |  0|[5, 7]|
# |  1|[5, 7]|
# +---+------+
