Serialization - Is there any way to serialize a custom Transformer in a Spark ML Pipeline?

zd287kbt · Posted 2021-05-27 in Spark
Followers (0) | Answers (2) | Views (467)

I'm using an ML Pipeline with various custom UDF-based Transformers. What I'm looking for is a way to serialize/deserialize this pipeline.
I serialize it using

ObjectOutputStream.write()

However, whenever I try to deserialize the pipeline I get:

java.lang.ClassNotFoundException: org.sparkexample.DateTransformer

where DateTransformer is my custom Transformer. Is there a method/interface or something to implement to get the serialization right?
I found that there is the

MLWritable

interface, which my class (DateTransformer extends Transformer) could perhaps implement, but I can't find a useful example of it.
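
For completeness, the plain Java-serialization attempt mentioned above looks roughly like this (the path and the fittedPipeline value are just placeholders):

import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.ml.PipelineModel

// Plain Java serialization of the fitted pipeline; writing works fine.
def writePipeline(fittedPipeline: PipelineModel, path: String): Unit = {
  val out = new ObjectOutputStream(new FileOutputStream(path))
  out.writeObject(fittedPipeline)
  out.close()
}

// Reading it back is where the ClassNotFoundException for DateTransformer
// shows up, since the custom class has to be resolvable by the deserializing JVM.
def readPipeline(path: String): PipelineModel = {
  val in = new ObjectInputStream(new FileInputStream(path))
  try in.readObject().asInstanceOf[PipelineModel]
  finally in.close()
}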

smdncfj3 · Answer #1

If you are using Spark 2.x+, extend your Transformer with DefaultParamsWritable.
For example:

class ProbabilityMaxer extends Transformer with DefaultParamsWritable {

Then create a constructor with a String (uid) parameter:

  // String-uid constructor: this is the constructor that DefaultParamsReader
  // looks up by reflection when the transformer is loaded back.
  def this(_uid: String) {
    this()
  }

Finally, add a companion object so that reading it back succeeds:

object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]

I have this working on my production server. I will add a GitLab link to the project once I have uploaded it.
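
Putting those pieces together, a minimal sketch might look like the one below. The uid prefix, the inputCol param and the trivial upper-casing logic are just placeholders, and here the uid is taken in the primary constructor, which equally satisfies the String-constructor requirement:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, upper}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Minimal Spark 2.x transformer that is saveable/loadable through the public
// DefaultParamsWritable / DefaultParamsReadable traits.
class ProbabilityMaxer(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("probabilityMaxer"))

  val inputCol = new Param[String](this, "inputCol", "input column")
  def setInputCol(value: String): this.type = set(inputCol, value)

  // Placeholder logic: upper-case the input column into a hard-coded output column.
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("output", upper(col($(inputCol))))

  override def copy(extra: ParamMap): ProbabilityMaxer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields :+ StructField("output", StringType, nullable = true))
}

// Companion object: this is what Pipeline/PipelineModel.load uses to find a reader
// for the stage when the saved pipeline is read back.
object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]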

iklwldmw · Answer #2

In short: you can't, at least not easily.
The developers have gone out of their way to make adding a new Transformer/Estimator as hard as possible. Basically everything in org.apache.spark.ml.util.ReadWrite is private (apart from MLWritable and MLReadable), so there is no way to reuse any of the utility methods/classes/objects in there. There is also (as I'm sure you have already discovered) absolutely no documentation on how to do this, but hey, good code documents itself, right?!
Digging through the code in org.apache.spark.ml.util.ReadWrite and org.apache.spark.ml.feature.HashingTF, it seems that you need to override MLWritable.write and MLReadable.read. DefaultParamsWriter and DefaultParamsReader, which appear to contain the actual save/load implementations, save and load a set of metadata:

timestamp
sparkVersion
uid
paramMap
(optionally, extra metadata)
So any implementation would need to cover at least those, and a Transformer that doesn't learn a model can probably get away with just that. A model that needs to be fitted also has to save its data in its save/write implementation - this is, for example, what LocalLDAModel does (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/clustering/lda.scala#l523), so the learned model is (it seems) simply saved as a Parquet file:

val data = sqlContext.read.parquet(dataPath)
        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
          "gammaShape")
        .head()
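
The corresponding save side isn't shown above; very roughly (this is a simplified illustration with made-up field names, not LocalLDAModel's actual code), a fitted model packs its learned state into a one-row DataFrame and writes it as Parquet under <path>/data, next to the metadata:

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SQLContext

// Simplified illustration of the save side: pack the learned state into a
// one-row DataFrame and write it as Parquet under <path>/data, alongside the
// metadata written by DefaultParamsWriter.
def saveModelData(sqlContext: SQLContext, path: String,
                  vocabSize: Int, gammaShape: Double): Unit = {
  import sqlContext.implicits._
  val dataPath = new Path(path, "data").toString
  Seq((vocabSize, gammaShape)).toDF("vocabSize", "gammaShape")
    .repartition(1)
    .write.parquet(dataPath)
}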

As a test, I copied the parts of org.apache.spark.ml.util.ReadWrite that seemed necessary and tested the following Transformer, which doesn't do anything useful.
WARNING: this is almost certainly the wrong way to do it and is very likely to break in the future. I sincerely hope I have misunderstood something and that someone will correct me on how to actually create a Transformer that can be serialized/deserialized.
This is for Spark 1.6.3, so it may already be broken if you are using 2.x:

import org.apache.spark.sql.types._
import org.apache.spark.ml.param._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.mllib.linalg._

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object CustomTransform extends DefaultParamsReadable[CustomTransform] {
  /* Companion object for deserialisation */
  override def load(path: String): CustomTransform = super.load(path)
}

class CustomTransform(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("customThing"))

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol(): String = getOrDefault(outputCol)

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  override def transform(dataset: DataFrame): DataFrame = {
    val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
    import sqlContext.implicits._

    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")
    val transformUDF = udf({ vector: SparseVector =>
      vector.values.map( _ * 10 )
      // WHAT EVER YOUR TRANSFORMER NEEDS TO DO GOES HERE
    })

    dataset.withColumn(outCol, transformUDF(col(inCol)))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    val outputFields = schema.fields :+ StructField(extractParamMap.getOrElse(outputCol, "filtered"), new VectorUDT, nullable = false)
    StructType(outputFields)
  }
}

Then we need all the utilities from org.apache.spark.ml.util.ReadWrite (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/util/readwrite.scala):

trait DefaultParamsWritable extends MLWritable { self: Params =>
  override def write: MLWriter = new DefaultParamsWriter(this)
}

trait DefaultParamsReadable[T] extends MLReadable[T] {
  override def read: MLReader[T] = new DefaultParamsReader
}

class DefaultParamsWriter(instance: Params) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

object DefaultParamsWriter {

  /**
    * Saves metadata + Params to: path + "/metadata"
    *  - class
    *  - timestamp
    *  - sparkVersion
    *  - uid
    *  - paramMap
    *  - (optionally, extra metadata)
    * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
    * @param paramMap  If given, this is saved in the "paramMap" field.
    *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
    *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
    ("timestamp" -> System.currentTimeMillis()) ~
    ("sparkVersion" -> sc.version) ~
    ("uid" -> uid) ~
    ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}

class DefaultParamsReader[T] extends MLReader[T] {
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true, Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}

object DefaultParamsReader {

  /**
    * All info from metadata file.
    *
    * @param params       paramMap, as a [[JValue]]
    * @param metadata     All metadata, including the other fields
    * @param metadataJson Full metadata file String (for debugging)
    */
  case class Metadata(
                       className: String,
                       uid: String,
                       timestamp: Long,
                       sparkVersion: String,
                       params: JValue,
                       metadata: JValue,
                       metadataJson: String)

  /**
    * Load metadata from file.
    *
    * @param expectedClassName If non empty, this is checked against the loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
    */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
    * Extract Params from metadata, and set them in the instance.
    * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
    */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
    * Load a [[Params]] instance from the given path, and return it.
    * This assumes the instance implements [[MLReadable]].
    */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true, Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}

With that you can use the CustomTransform in a Pipeline and save/load the pipeline. I tested it quickly in spark-shell and it seems to work, but it is certainly not pretty.
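
A quick check along those lines (spark-shell 1.6 style, where sqlContext is provided by the shell; the column names and the output path are placeholders) looks roughly like this:

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.mllib.linalg.Vectors

// Tiny DataFrame with a sparse-vector column for the UDF above.
import sqlContext.implicits._
val df = Seq((1, Vectors.sparse(3, Array(0, 2), Array(1.0, 2.0)))).toDF("id", "features")

// Put the custom transformer into a Pipeline, save it, and load it back.
val custom = new CustomTransform().setInputCol("features").setOutputCol("scaled")
val pipeline = new Pipeline().setStages(Array[PipelineStage](custom))
val model: PipelineModel = pipeline.fit(df)

model.write.overwrite().save("/tmp/custom-pipeline")
val restored = PipelineModel.load("/tmp/custom-pipeline")
restored.transform(df).show()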
