Scala - How to write a Spark UDF that takes Array[StructType], StructType as input and returns Array[StructType]

cclgggtu · posted 2021-05-29 in Spark

I have a DataFrame with the following schema:

root
 |-- user_id: string (nullable = true)
 |-- user_loans_arr: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- loan_date: string (nullable = true)
 |    |    |-- loan_amount: string (nullable = true)
 |-- new_loan: struct (nullable = true)
 |    |-- loan_date: string (nullable = true)
 |    |-- loan_amount: string (nullable = true)

I want a UDF that takes user_loans_arr and new_loan as input, appends the new_loan struct to the existing user_loans_arr, and then removes from user_loans_arr every element whose loan_date is older than 12 months.
Thanks in advance.


eit6fx6z1#

You need to pass the array and struct columns to the UDF either as an array or as a struct. I prefer passing them as a struct; inside the UDF you can then manipulate the elements and return an array type.

import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import *

# Test data

tst = sqlContext.createDataFrame([(1,2,3,4),(3,4,5,1),(5,6,7,8),(7,8,9,2)],schema=['col1','col2','col3','col4'])
tst_1=(tst.withColumn("arr",F.array('col1','col2'))).withColumn("str",F.struct('col3','col4'))

# udf to return array

@udf(ArrayType(StringType()))
def fn(row):
    # row is a struct holding the array column 'arr' and the struct column 'str'
    if row.arr[1] > row.str.col4:
        res = []
    else:
        # append the struct's fields after the array elements
        res = row.arr + list(row.str.asDict().values())
    return res

# calling udf with a struct of array and struct column

tst_fin = tst_1.withColumn("res",fn(F.struct('arr','str')))

The result is:

tst_fin.show()
+----+----+----+----+------+------+------------+
|col1|col2|col3|col4|   arr|   str|         res|
+----+----+----+----+------+------+------------+
|   1|   2|   3|   4|[1, 2]|[3, 4]|[1, 2, 4, 3]|
|   3|   4|   5|   1|[3, 4]|[5, 1]|          []|
|   5|   6|   7|   8|[5, 6]|[7, 8]|[5, 6, 8, 7]|
|   7|   8|   9|   2|[7, 8]|[9, 2]|          []|
+----+----+----+----+------+------+------------+

This example treats everything as an int. Since your values are date strings, you will have to parse and compare them with Python's datetime functions inside the UDF.
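
For the schema in the question, the same pattern could look roughly like the sketch below. It is untested; df (the question's DataFrame), merge_and_filter and months_ago are illustrative names, and it assumes the loan_date strings are ISO-formatted (e.g. '2020-01-01'):

from datetime import date, datetime
import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType

# element type of user_loans_arr, matching the question's schema
loan_type = StructType([
    StructField("loan_date", StringType()),
    StructField("loan_amount", StringType()),
])

def months_ago(d):
    # whole calendar months between d and today (ignores the day of month)
    today = date.today()
    return (today.year - d.year) * 12 + (today.month - d.month)

@udf(ArrayType(loan_type))
def merge_and_filter(row):
    # row is a struct of (user_loans_arr, new_loan)
    loans = list(row.user_loans_arr or []) + [row.new_loan]
    # keep only loans taken within the last 12 months
    return [l for l in loans
            if months_ago(datetime.strptime(l.loan_date, "%Y-%m-%d").date()) < 12]

df2 = df.withColumn("user_loans_arr",
                    merge_and_filter(F.struct("user_loans_arr", "new_loan")))

Declaring the return type as an ArrayType of the same struct lets Spark map the returned rows back onto the original array<struct<loan_date, loan_amount>> element type.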


ryhaxcpt2#

If you are on spark >= 2.4 you don't need a UDF at all; see the example below.

Load the input data

val df = spark.sql(
      """
        |select user_id, user_loans_arr, new_loan
        |from values
        | ('u1', array(named_struct('loan_date', '2019-01-01', 'loan_amount', 100)), named_struct('loan_date',
        | '2020-01-01', 'loan_amount', 100)),
        | ('u2', array(named_struct('loan_date', '2020-01-01', 'loan_amount', 200)), named_struct('loan_date',
        | '2020-01-01', 'loan_amount', 100))
        | T(user_id, user_loans_arr, new_loan)
      """.stripMargin)
    df.show(false)
    df.printSchema()

    /**
      * +-------+-------------------+-----------------+
      * |user_id|user_loans_arr     |new_loan         |
      * +-------+-------------------+-----------------+
      * |u1     |[[2019-01-01, 100]]|[2020-01-01, 100]|
      * |u2     |[[2020-01-01, 200]]|[2020-01-01, 100]|
      * +-------+-------------------+-----------------+
      *
      * root
      * |-- user_id: string (nullable = false)
      * |-- user_loans_arr: array (nullable = false)
      * |    |-- element: struct (containsNull = false)
      * |    |    |-- loan_date: string (nullable = false)
      * |    |    |-- loan_amount: integer (nullable = false)
      * |-- new_loan: struct (nullable = false)
      * |    |-- loan_date: string (nullable = false)
      * |    |-- loan_amount: integer (nullable = false)
      */

Process it according to the requirement:

Take user_loans_arr and new_loan as input and append the new_loan struct to the existing user_loans_arr, then drop from user_loans_arr every element whose loan_date is older than 12 months.

`spark >= 2.4`
df.withColumn("user_loans_arr",
expr(
"""
|FILTER(array_union(user_loans_arr, array(new_loan)),
| x -> months_between(current_date(), to_date(x.loan_date)) < 12)
""".stripMargin))
.show(false)

/**
  * +-------+--------------------------------------+-----------------+
  * |user_id|user_loans_arr                        |new_loan         |
  * +-------+--------------------------------------+-----------------+
  * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
  * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
  * +-------+--------------------------------------+-----------------+
  */

`spark < 2.4`
// spark < 2.4
import scala.collection.mutable
import java.time._
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
import spark.implicits._

// reuse the array<struct<...>> type of the existing column as the UDF's return type
val outputSchema = df.schema("user_loans_arr").dataType

// append the new loan, then keep only loans taken within the last 12 months
val add_and_filter = udf((userLoansArr: mutable.WrappedArray[Row], loan: Row) => {
  (userLoansArr :+ loan).filter(row => {
    val loanDate = LocalDate.parse(row.getAs[String]("loan_date"))
    val period = Period.between(loanDate, LocalDate.now())
    period.getYears * 12 + period.getMonths < 12
  })
}, outputSchema)

df.withColumn("user_loans_arr", add_and_filter($"user_loans_arr", $"new_loan"))
  .show(false)

/**
  * +-------+--------------------------------------+-----------------+
  * |user_id|user_loans_arr                        |new_loan         |
  * +-------+--------------------------------------+-----------------+
  * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
  * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
  * +-------+--------------------------------------+-----------------+
  */
