Generic UDAF with Aggregator in Spark 3.0

xmakbtuz · asked 2021-05-27 in Spark

Spark 3.0 has deprecated UserDefinedAggregateFunction, so I am trying to use Aggregator instead. Basic usage of Aggregator is simple, but I am struggling to come up with a more generic version of the function.
I will try to explain my problem with an example that mimics collect_set. It is not my actual case, but it is easier to explain the problem with:

import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  override def zero = Set.empty
  override def reduce(b: Set[Int], a: Row) = b + a.getInt(a.fieldIndex(name))
  override def merge(b1: Set[Int], b2: Set[Int]) = b1 ++ b2
  override def finish(reduction: Set[Int]) = reduction
  override def bufferEncoder = Encoders.kryo[Set[Int]]
  override def outputEncoder = ExpressionEncoder() // derived from the implicit TypeTag[Set[Int]]
}

// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()

I prefer using it via .toColumn rather than .udf.register, but that is not the point here.
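
For reference, the Spark 3.0 counterpart of the register route is functions.udaf, which wraps an Aggregator in a UserDefinedFunction. A minimal sketch, assuming a DataFrame df is in scope (the variable name collectSetDemo is made up here): because this aggregator consumes whole Rows, it needs an explicit RowEncoder for the input and is invoked with the DataFrame's columns in schema order:

import org.apache.spark.sql.functions.{col, udaf}
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Sketch: wrap the Aggregator as a UDAF. RowEncoder(df.schema) tells Spark
// how to reassemble the input Row from the columns passed at the call site.
val collectSetDemo = udaf(new CollectSetDemoAgg("rank"), RowEncoder(df.schema))
df.agg(collectSetDemo(df.columns.map(col): _*) as "result").show()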
The problem: I cannot make a generic version of this aggregator; it only works with integers.
I tried:

class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]]

It crashes with the error:

java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
    at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)

I could not go with CollectSetDemo[T]; with a generic type I was not able to get outputEncoder to work. Also, when using a UDAF I can only operate on Spark data types, columns, etc.
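
Note: ExpressionEncoder() is resolved from an implicit TypeTag, which an unbound T does not provide; Any does have a TypeTag, but no encoder can be derived for it, hence the error above. A minimal sketch of the missing piece, assuming a T: TypeTag context bound is acceptable:

import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// This compiles once a TypeTag[T] is in scope; without the context bound it
// fails with "No TypeTag available for Seq[T]".
def seqEncoder[T: TypeTag]: ExpressionEncoder[Seq[T]] = ExpressionEncoder[Seq[T]]()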


ruyhziif · Answer 1

A version of @ramunas' answer, modified with generics:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection.{deserializerForType, mirror, serializerForType}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg[T: TypeTag](name: String) extends Aggregator[Row, Set[T], Seq[T]] {
  override def zero = Set.empty
  override def reduce(b: Set[T], a: Row) = b + a.getAs[T](a.fieldIndex(name))
  override def merge(b1: Set[T], b2: Set[T]) = b1 ++ b2
  override def finish(reduction: Set[T]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[T]]

  // Build the output encoder by hand: resolve the runtime type of Seq[T],
  // then derive serializer/deserializer expressions for it.
  override def outputEncoder = {
    val tt = typeTag[Seq[T]]
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = serializerForType(tpe)
    val deserializer = deserializerForType(tpe)

    new ExpressionEncoder[Seq[T]](serializer, deserializer, ClassTag[Seq[T]](cls))
  }
}
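
Usage then spells out the element type at the call site, e.g. with the integer rank column from the question:

df.agg(new CollectSetDemoAgg[Int]("rank").toColumn as "result").show()

Side note: with the T: TypeTag bound in scope, the hand-built outputEncoder body should be equivalent to what ExpressionEncoder[Seq[T]]() would derive implicitly.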

p3rjfoxz · Answer 2

I have not found a nice way to solve this, but I do have a workaround. The code is partially borrowed from RowEncoder:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.typeTag
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._

class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
  override def zero = Set.empty
  override def reduce(b: Set[Any], a: Row) = b + a.get(a.fieldIndex(name))
  override def merge(b1: Set[Any], b2: Set[Any]) = b1 ++ b2
  override def finish(reduction: Set[Any]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[Any]]

  // Map the declared Spark DataType to a Scala TypeTag, then derive the
  // serializer/deserializer expressions for it, as RowEncoder does internally.
  override def outputEncoder = {
    val mirror = ScalaReflection.mirror
    val tt = fieldType match {
      case ArrayType(LongType, _) => typeTag[Seq[Long]]
      case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
      case ArrayType(StringType, _) => typeTag[Seq[String]]
      // .. etc etc
      case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
    }
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = ScalaReflection.serializerForType(tpe)
    val deserializer = ScalaReflection.deserializerForType(tpe)

    new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
  }
}

The one thing I had to add is a result data type parameter on the aggregator. Usage changed to:

df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()

I do not really like how it turned out, but it works. I would welcome any suggestions on how to improve it.
