如何在Spark SQL中扩展内置聚合函数(使用Scala)?

yqkkidmi  于 2023-08-05  发布在  Scala
关注(0)|答案(1)|浏览(101)

基本上,最终目标是创建类似dollarSum的东西,它将返回与ROUND(SUM(col), 2)相同的值。
我使用的是Databricks runtime 10.4 LTS ML,它显然对应于Spark 3.2.1和Scala 2.12。
我能够遵循tutorial / example code for UDAFs,并使用它创建类似于内置EVERY函数的东西。但这似乎更像是ImperativeAggregate,而我想要的可能更像是DeclarativeAggregate,参见。Spark源代码中的注解。
总的来说,我还没有在网上找到任何关于如何以简单的方式扩展内置聚合函数的文档,在这种情况下,你只需要修改“完成”或“评估”步骤,甚至只需要添加额外的行为。

到目前为止我已经尝试了到目前为止我已经尝试了至少四种方法,没有一种有效。
尝试1:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{sum, round}

object dollarSum extends Aggregator[Double, Double, Double] {

def zero: Double = sum.zero

def reduce(buffer: Double, row: Double): Double = sum.reduce

def merge(buffer1: Double, buffer2: Double) Double = sum.merge

def finish(reduction: Double): Double = {
    sum.finish(reduction)
    round(reduction, 2)
}

def bufferEncoder: Encoder[Double] = sum.bufferEncoder
def outputEncoder: Encoder[Double] = sum.outputEncoder
}

字符串

**尝试2:**我尝试从here复制粘贴修改代码。这似乎失败了,因为内置Sum类的大多数属性和方法似乎是私有的(可能是因为开发人员不希望像我这样不知道自己在做什么的人破坏代码)。但是我不知道我可以使用什么公共接口/ API来获得我想要的。

import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.round
import org.apache.spark.sql.catalyst.expressions.EvalMode
import org.apache.spark.sql.types.DecimalType

trait dollarSum extends Sum {

  override lazy val evaluateExpression: Expression = {
    Sum.resultType match {
      case d: DecimalType =>
        val checkOverflowInSum =
          CheckOverflowInSum(Sum.sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
        If(isEmpty, Literal.create(null, Sum.resultType), checkOverflowInSum)
      case _ if shouldTrackIsEmpty =>
        If(isEmpty, Literal.create(null, Sum.resultType), Sum.sum)
      case _ => round(Sum.sum, 2)
    }
  }

}


这可能仍然会失败,因为其他一些丢失的导入,但同样,由于试图访问可能不应该访问的私有方法和属性,我无法在调试中走那么远。

**尝试3:**同一文件中的try_sum源代码似乎更接近于使用“公共API”进行求和,所以我尝试复制-粘贴-修改。但是ExpressionBuilder看起来也是一个私有类,所以这也失败了。

import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression

object DollarSumExpressionBuilder extends ExpressionBuilder {
  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    val numArgs = expressions.length
    if (numArgs == 1) {
      round(Sum(expressions.head),2)
    } else {
      throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
    }
  }
}


然后,如果这起作用,我会尝试注册函数,就像在源代码中使用Spark SQL注册TRY_SUM一样,cf。但是我得到了一个关于ExpressionBuilder不存在的错误,这似乎表明它也是包的私有类,因此不是我可以用来扩展SUM的公共接口。
我也不清楚SUM构造函数的返回类型是什么,我想可能是AggregateExpression继承自Expression。我不确定round的输入类型是什么,似乎是org.apache.spark.sql.Column,如果是这样,我不知道如何从Expression转换为Column
例如,在上述

round(org.apache.spark.sql.Column((Sum(expressions.head)),2)


或者是

round(org.apache.spark.sql.functions.col((Sum(expressions.head)),2)


将能够实现所需的类型转换(似乎两者都不起作用)。

**尝试4:**沿着上面的路线,不知道需要哪些类型以及如何在它们之间转换,以及SUM的公共接口是什么,我尝试使用org.apache.spark.sql.functions.sum作为SUM的“公共接口”,但这也不起作用。

具体来说

import org.apache.spark.sql.functions.{round, sum}
import org.apache.spark.sql.Column

// originally I had `expression: org.apache.spark.sql.catalyst.expressions.Expression` but that didn't work
def dollarSum(expression: Column): Column = {round(sum(expression), 2)}


实际上不会抛出任何错误,但是当我尝试将结果对象实际注册为(n聚合)函数时,它失败了,具体来说

spark.udf.register("dollar_sum", functions.udaf(dollarSum))


不管用,也不管用

spark.udf.register("dollar_sum", functions.udf(dollarSum))

1wnzp6jl

1wnzp6jl1#

哇,这个问题里有很多有趣的东西,而且非常熟悉:Quality's agg_expr是我进入这个领域的旅程。
要构建自定义表达式,您可能需要将代码放入org.apache.spark.sql包中,例如:registerFunction.使用SparkSession示例FunctionRegistry createOrReplaceTempFunction(例如sparkSession.getActiveSession.get.sessionState.functionRegistry),您可以在会话中使用该函数。如果你需要它在Hive视图等。您必须使用SparkSessionExtensions作为scope * 和 * FunctionRegistry.builtin.registerFunction.
实际的注册ExpressionBuilder只是Seq[Expression] => Expression的别名,表示传入构造表达式的参数。
因此,根据Spark版本(内部API变化很大):

import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.expressions.{Round, Literal, EvalMode}
    import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

    SparkSession.getActiveSession.get.sessionState.functionRegistry.
      createOrReplaceTempFunction("dollarSum", exps => Round(
        Sum(exps.head, EvalMode.TRY).toAggregateExpression(), Literal(2)), "built-in")

    val seq = Seq(1.245, 242.535, 65656.234425, 2343.666)
    import sparkSession.implicits._

    seq.toDF("amount")//.selectExpr("round(sum(amount), 2)").show
      .selectExpr("dollarSum(amount)").show

字符串
NB/FYI:质量的一个明显想法是使用lambda

import com.sparkutils.quality.{LambdaFunction, Id, registerLambdaFunctions, registerQualityFunctions}
    registerQualityFunctions()
    registerLambdaFunctions(Seq(
      LambdaFunction("dollarSum", "a -> round(sum(a), 2)", Id(1,1))
    ))


然而,这失败了,因为Spark Lambda Function和AggregateFunction不容易混合。直接的FunctionRegistry路由不涉及LambdaFunction,因此可以正常工作。

相关问题