Spark自定义聚合:collect_list+自定义项与UDAF

3htmauhk  于 2023-01-09  发布在  Apache
关注(0)|答案(1)|浏览(198)

在Spark 2.1中,我经常需要对 Dataframe 执行自定义聚合,并使用了以下两种方法:

  • 使用groupby/collect_list获取单行中的所有值,然后应用UDF来聚合这些值
  • 编写定制的UDAF(用户定义的聚合函数)

我通常更喜欢第一种选择,因为它比UDAF实现更容易实现和更可读。但是我会假设第一种选择通常更慢,因为更多的数据在网络上发送(没有部分聚合),但是我的经验表明UDAF通常很慢。为什么呢?
具体示例:计算直方图
数据位于配置单元表中(1 E6个随机双精度值)

val df = spark.table("testtable")

def roundToMultiple(d:Double,multiple:Double) = Math.round(d/multiple)*multiple

UDF方法:

val udf_histo = udf((xs:Seq[Double]) => xs.groupBy(x => roundToMultiple(x,0.25)).mapValues(_.size))

df.groupBy().agg(collect_list($"x").as("xs")).select(udf_histo($"xs")).show(false)

+--------------------------------------------------------------------------------+
|UDF(xs)                                                                         |
+--------------------------------------------------------------------------------+
|Map(0.0 -> 125122, 1.0 -> 124772, 0.75 -> 250819, 0.5 -> 248696, 0.25 -> 250591)|
+--------------------------------------------------------------------------------+

UDAF方法

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.mutable

class HistoUDAF(binWidth:Double) extends UserDefinedAggregateFunction {

  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) :: Nil
    )

  override def bufferSchema: StructType =
    new StructType()
      .add("histo", MapType(DoubleType, IntegerType))

  override def deterministic: Boolean = true
  override def dataType: DataType = MapType(DoubleType, IntegerType)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[Double, Int]()
  }
  
  private def mergeMaps(a: Map[Double, Int], b: Map[Double, Int]) = {
    a ++ b.map { case (k,v) => k -> (v + a.getOrElse(k, 0)) }
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
     val oldBuffer = buffer.getAs[Map[Double, Int]](0)
     val newInput = Map(roundToMultiple(input.getDouble(0),binWidth) -> 1) 
     buffer(0) = mergeMaps(oldBuffer, newInput)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val a = buffer1.getAs[Map[Double, Int]](0)
    val b = buffer2.getAs[Map[Double, Int]](0)
    buffer1(0) = mergeMaps(a, b)
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Map[Double, Int]](0)
  }
}

val histo = new HistoUDAF(0.25)

df.groupBy().agg(histo($"x")).show(false)

+--------------------------------------------------------------------------------+
|histoudaf(x)                                                                    |
+--------------------------------------------------------------------------------+
|Map(0.0 -> 125122, 1.0 -> 124772, 0.75 -> 250819, 0.5 -> 248696, 0.25 -> 250591)|
+--------------------------------------------------------------------------------+

我的测试表明collect_list/UDF方法比UDAF方法快2倍,这是一个普遍的规律,还是在某些情况下,UDAF确实快得多,而相当笨拙的实现是合理的?

ergxz8rk

ergxz8rk1#

UDAF比较慢,因为它在每一行的每一次更新-〉时都要从/向内部缓冲区反序列化/序列化聚合器,这是相当昂贵的(some more details),相反,您应该使用Aggregator(事实上,从Spark 3.0开始,UDAF就是deprecated)。

相关问题