pyspark: explode multiple array columns of variable length

ih99xse1 · posted 2022-11-01 in Spark

How do I explode multiple array columns that have variable lengths and potentially null values?
My input data looks like this:

+----+------------+--------------+--------------------+
|col1|        col2|          col3|                col4|
+----+------------+--------------+--------------------+
|   1|[id_1, id_2]|  [tim, steve]|       [apple, pear]|
|   2|[id_3, id_4]|       [jenny]|           [avocado]|
|   3|        null|[tommy, megan]| [apple, strawberry]|
|   4|        null|          null|[banana, strawberry]|
+----+------------+--------------+--------------------+

I need to explode it such that:
1. Array items with the same index map to the same row
2. If a column has only a single entry, that entry applies to every exploded row
3. If an array is null, null applies to every row
My output should look like this:

+----+----+-----+----------+
|col1|col2|col3 |col4      |
+----+----+-----+----------+
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|jenny|avocado   |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
+----+----+-----+----------+

I've been able to achieve this with the code below, but I feel there must be a more straightforward way:

df = spark.createDataFrame(
    [
        (1, ["id_1", "id_2"], ["tim", "steve"], ["apple", "pear"]),
        (2, ["id_3", "id_4"], ["jenny"], ["avocado"]),
        (3, None, ["tommy", "megan"], ["apple", "strawberry"]),
        (4, None, None, ["banana", "strawberry"])
    ],
    ["col1", "col2", "col3", "col4"]
)

df.createOrReplaceTempView("my_table")

spark.sql("""
with cte as (
  SELECT
  col1,
  col2,
  col3,
  col4,
  greatest(size(col2), size(col3), size(col4)) as max_array_len
  FROM my_table
), arrays_extended as (
select 
col1,
case 
  when col2 is null then array_repeat(null, max_array_len) 
  else col2 
end as col2, 
case
  when size(col3) = 1 then array_repeat(col3[0], max_array_len)
  when col3 is null then array_repeat(null, max_array_len)
  else col3
end as col3,
case
  when size(col4) = 1 then array_repeat(col4[0], max_array_len)
  when col4 is null then array_repeat(null, max_array_len)
  else col4
end as col4
from cte), 
arrays_zipped as (
select *, explode(arrays_zip(col2, col3, col4)) as zipped
from arrays_extended
)
select 
  col1,
  zipped.col2,
  zipped.col3,
  zipped.col4
from arrays_zipped
""").show(truncate=False)

sq1bmfud 1#

Once you have max_array_len, just use the sequence function to iterate through the arrays, convert them into structs, and then explode the resulting array of structs; see the SQL below:

spark.sql("""
  with cte as (
    SELECT      
      col1,         
      col2,
      col3,
      col4,
      greatest(size(col2), size(col3), size(col4)) as max_array_len
    FROM my_table
  )
  SELECT inline_outer(
           transform(
             sequence(0,max_array_len-1), i -> (
               col1 as col1,
               col2[i] as col2,
               coalesce(col3[i], col3[0]) as col3,             /* fill null with the first array item of col3 */
               coalesce(col4[i], element_at(col4,-1)) as col4  /* fill null with the last array item of col4 */
             )
           )
         )
  FROM cte
""").show()
+----+----+-----+----------+
|col1|col2| col3|      col4|
+----+----+-----+----------+
|   1|id_1|  tim|     apple|
|   1|id_2|steve|      pear|
|   2|id_3|jenny|   avocado|
|   2|id_4|jenny|   avocado|
|   3|null|tommy|     apple|
|   3|null|megan|strawberry|
|   4|null| null|    banana|
|   4|null| null|strawberry|
+----+----+-----+----------+
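
If you prefer the DataFrame API over spark.sql, a rough sketch of the same approach (using the df defined in the question) could look like this:

from pyspark.sql import functions as F

result = (
    df
    # the longest array in the row drives how many output rows are generated
    .withColumn("max_array_len",
                F.greatest(F.size("col2"), F.size("col3"), F.size("col4")))
    .selectExpr("""
        inline_outer(
          transform(
            sequence(0, max_array_len - 1), i -> (
              col1 as col1,
              col2[i] as col2,
              coalesce(col3[i], col3[0]) as col3,
              coalesce(col4[i], element_at(col4, -1)) as col4
            )
          )
        )
    """)
)
result.show()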

A similar question here.

wko9yo5t 2#

You can use inline_outer in combination with selectExpr, and additionally use coalesce to take the first non-null value, which handles the size mismatch across the different arrays.

Data preparation

from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType, StringType

inp_data = [
    (1, ['id_1', 'id_2'], ['tim', 'steve'], ['apple', 'pear']),
    (2, ['id_3', 'id_4'], ['jenny'], ['avocado']),
    (3, None, ['tommy', 'megan'], ['apple', 'strawberry']),
    (4, None, None, ['banana', 'strawberry'])
]

inp_schema = StructType([
    StructField('col1', IntegerType(), True),
    StructField('col2', ArrayType(StringType(), True)),
    StructField('col3', ArrayType(StringType(), True)),
    StructField('col4', ArrayType(StringType(), True))
])

sparkDF = spark.createDataFrame(data=inp_data, schema=inp_schema)

sparkDF.show(truncate=False)

+----+------------+--------------+--------------------+
|col1|col2        |col3          |col4                |
+----+------------+--------------+--------------------+
|1   |[id_1, id_2]|[tim, steve]  |[apple, pear]       |
|2   |[id_3, id_4]|[jenny]       |[avocado]           |
|3   |null        |[tommy, megan]|[apple, strawberry] |
|4   |null        |null          |[banana, strawberry]|
+----+------------+--------------+--------------------+

inline_outer

sparkDF.selectExpr(
    "col1",
    """inline_outer(arrays_zip(
           coalesce(col2, array()),
           coalesce(col3, array()),
           coalesce(col4, array())
       ))"""
).show(truncate=False)

+----+----+-----+----------+
|col1|0   |1    |2         |
+----+----+-----+----------+
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|null |null      |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
+----+----+-----+----------+
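
The zipped struct fields come out named 0, 1 and 2 because arrays_zip is fed expressions rather than plain column references. As a small follow-up (a sketch using the same sparkDF), you could rename them back to the original column names with toDF:

out = sparkDF.selectExpr(
    "col1",
    """inline_outer(arrays_zip(
           coalesce(col2, array()),
           coalesce(col3, array()),
           coalesce(col4, array())
       ))"""
).toDF("col1", "col2", "col3", "col4")  # rename positional fields 0/1/2

out.show(truncate=False)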

ssgvzors 3#

You can use a UDF:

from pyspark.sql import functions as F, types as T

cols_of_interest = [c for c in df.columns if c != 'col1']

@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
def get_sequences(*cols):
    """Equivalent of arrays_zip, but handling different lengths of the arrays.
       For shorter array than the maximum length last element is repeated.
    """

    # Get the length of the longest array in the row
    max_len = max(map(len, filter(lambda x: x, cols)))

    return list(zip(*[
        # create a list for each column with a length equal to the max_len.
        # If the original column has less elements than needed, repeat the last one.
        # None values will be filled with a list of Nones with length max_len.
        [c[min(i, len(c) - 1)]  for i in range(max_len)] if c else [None] * max_len for c in cols
    ]))

df2 = (
    df
    .withColumn('temp', F.explode(get_sequences(*cols_of_interest)))
    .select('col1', 
            *[F.col('temp').getItem(i).alias(c) for i, c in enumerate(cols_of_interest)])
)

df2 is the following DataFrame:

+----+----+-----+----------+
|col1|col2| col3|      col4|
+----+----+-----+----------+
|   1|id_1|  tim|     apple|
|   1|id_2|steve|      pear|
|   2|id_3|jenny|   avocado|
|   2|id_4|jenny|   avocado|
|   3|null|tommy|     apple|
|   3|null|megan|strawberry|
|   4|null| null|    banana|
|   4|null| null|strawberry|
+----+----+-----+----------+
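
To see what the UDF is doing for a single row, here is a plain-Python walk-through of its padding-and-zipping logic, using the values of row 2 as an assumed example (not part of the original answer):

# col2, col3, col4 of row 2
cols = (["id_3", "id_4"], ["jenny"], ["avocado"])

# length of the longest non-null array in the row
max_len = max(map(len, filter(lambda x: x, cols)))   # 2

# pad shorter arrays by repeating their last element; a null column would
# become [None] * max_len
padded = [
    [c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len
    for c in cols
]

print(list(zip(*padded)))
# [('id_3', 'jenny', 'avocado'), ('id_4', 'jenny', 'avocado')]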

qmb5sa22 4#

I used your logic and shortened it a bit.

import pyspark.sql.functions as func

# data_sdf is the input DataFrame from the question
arrcols = ['col2', 'col3', 'col4']

data_sdf. \
    selectExpr(*['coalesce({0}, array()) as {0}'.format(c) if c in arrcols else c for c in data_sdf.columns]). \
    withColumn('max_size', func.greatest(*[func.size(c) for c in arrcols])). \
    selectExpr('col1', 
               *['flatten(array({0}, array_repeat(element_at({0}, -1), max_size-size({0})))) as {0}'.format(c) for c in arrcols]
               ). \
    withColumn('arrzip', func.arrays_zip(*arrcols)). \
    selectExpr('col1', 'inline(arrzip)'). \
    orderBy('col1', 'col2'). \
    show()

# +----+----+-----+----------+
# |col1|col2| col3|      col4|
# +----+----+-----+----------+
# |   1|id_1|  tim|     apple|
# |   1|id_2|steve|      pear|
# |   2|id_3|jenny|   avocado|
# |   2|id_4|jenny|   avocado|
# |   3|null|megan|strawberry|
# |   3|null|tommy|     apple|
# |   4|null| null|    banana|
# |   4|null| null|strawberry|
# +----+----+-----+----------+

Approach / steps

  • Fill nulls with an empty array, and take the greatest size across all the array columns (the expressions this generates are expanded below)
  • Append elements to the arrays that are smaller than the others
  • I take the last element of the array and use array_repeat on it (similar to your approach)
  • The number of repetitions is calculated by comparing the max size with the size of the array being processed (max_size-size({0}))
  • With the steps above, every array column now has the same number of elements, which lets you zip them (arrays_zip) and explode (using the inline() SQL function)
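
The list comprehension in the first selectExpr (the null-filling step) expands to the following; col1 is passed through unchanged because it is not in arrcols:

['coalesce({0}, array()) as {0}'.format(c) if c in arrcols else c for c in data_sdf.columns]

# ['col1',
#  'coalesce(col2, array()) as col2',
#  'coalesce(col3, array()) as col3',
#  'coalesce(col4, array()) as col4']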

The list comprehension in the second selectExpr generates the following:

['flatten(array({0}, array_repeat(element_at({0}, -1), max_size-size({0})))) as {0}'.format(c) for c in arrcols]

# ['flatten(array(col2, array_repeat(element_at(col2, -1), max_size-size(col2)))) as col2',
#  'flatten(array(col3, array_repeat(element_at(col3, -1), max_size-size(col3)))) as col3',
#  'flatten(array(col4, array_repeat(element_at(col4, -1), max_size-size(col4)))) as col4']

In case it helps, here are the optimized logical plan and the physical plan that Spark generates:

== Optimized Logical Plan ==
Generate inline(arrzip#363), [1], false, [col2#369, col3#370, col4#371]
+- Project [col1#0L, arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4) AS arrzip#363]
   +- Filter (size(arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4), true) > 0)
      +- LogicalRDD [col1#0L, col2#1, col3#2, col4#3], false

== Physical Plan ==
Generate inline(arrzip#363), [col1#0L], false, [col2#369, col3#370, col4#371]
+- *(1) Project [col1#0L, arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4) AS arrzip#363]
   +- *(1) Filter (size(arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4), true) > 0)
      +- *(1) Scan ExistingRDD[col1#0L,col2#1,col3#2,col4#3]
