Advice on refactoring Scala: can I eliminate the vars used in foreach loops?

Posted by vm0i2vca on 2021-05-27 in Spark

I'd like to know how to refactor some Scala code to make it more elegant and more idiomatic.
I have a function

def joinDataFramesOnColumns(joinColumns: Seq[String]) : org.apache.spark.sql.DataFrame

that operates on a Seq[org.apache.spark.sql.DataFrame] and joins them together on joinColumns. The function is defined as follows:

implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
    def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
      val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
      val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
      if (nonEmptyDataFrames.isEmpty){
        emptyDataFrame
      }
      else {
        if (joinColumns.isEmpty) {
          return nonEmptyDataFrames.reduce(_.crossJoin(_))
        }
      nonEmptyDataFrames.reduce(_.join(_, joinColumns))
    }
  }
}

I have some unit tests, all of which pass:

class FeatureGeneratorDataFrameExtensionsTest extends WordSpec {
  val fruitValues = Seq(
    Row(0, "BasketA", "Bananas", "Jack"),
    Row(2, "BasketB", "Oranges", "Jack"),
    Row(2, "BasketC", "Oranges", "Jill"),
    Row(3, "BasketD", "Oranges", "Jack"),
    Row(4, "BasketE", "Oranges", "Jack"),
    Row(4, "BasketE", "Apples", "Jack"),
    Row(4, "BasketF", "Bananas", "Jill")
  )
  val schema = List(
    StructField("weeksPrior", IntegerType, true),
    StructField("basket", StringType, true),
    StructField("Product", StringType, true),
    StructField("Customer", StringType, true)
  )
  val fruitDf = spark.createDataFrame(
    spark.sparkContext.parallelize(fruitValues),
    StructType(schema)
  ).withColumn("Date", udfDateSubWeeks(lit(dayPriorToAsAt), col("weeksPrior")))

  "FeatureGenerator.SequenceOfDataFrames" should {
    "join multiple dataframes on a specified set of columns" in {
      val sequenceOfDataFrames = Seq[DataFrame](
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior1"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior2"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior3"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior4"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior5")
      )
      val joinedDataFrames = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product", "Customer", "Date"))
      assert(joinedDataFrames.columns.length === 9)
      assert(joinedDataFrames.columns.contains("basket"))
      assert(joinedDataFrames.columns.contains("Product"))
      assert(joinedDataFrames.columns.contains("Customer"))
      assert(joinedDataFrames.columns.contains("Date"))
      assert(joinedDataFrames.columns.contains("weeksPrior1"))
      assert(joinedDataFrames.columns.contains("weeksPrior2"))
      assert(joinedDataFrames.columns.contains("weeksPrior3"))
      assert(joinedDataFrames.columns.contains("weeksPrior4"))
      assert(joinedDataFrames.columns.contains("weeksPrior5"))
    }
    "when passed a list of one dataframe return that same dataframe" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product"))
      assert(joinedDataFrame.columns.sorted === fruitDf.columns.sorted)
      assert(joinedDataFrame.count === fruitDf.count)
    }
    "when passed an empty list of dataframes return an empty dataframe" in {
      val joinedDataFrame = Seq[DataFrame]().joinDataFramesOnColumns(Seq("basket"))
      assert(joinedDataFrame === spark.emptyDataFrame)
    }
    "when passed an empty list of joinColumns return the dataframes crossjoined" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf,fruitDf, fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq[String]())
      assert(joinedDataFrame.count === scala.math.pow(fruitDf.count, sequenceOfDataFrames.size))
      assert(joinedDataFrame.columns.size === fruitDf.columns.size * sequenceOfDataFrames.size)
    }
  }
}

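This all worked well until it started failing because of this Spark bug: https://issues.apache.org/jira/browse/SPARK-25150, which can cause errors in certain circumstances when the join columns have the same name.
The workaround is to alias the columns to something else, so I rewrote the function as follows. It aliases the join columns, performs the join, then renames them back: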

implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
    def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
      val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
      val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
      if (nonEmptyDataFrames.isEmpty){
        emptyDataFrame
      }
      else {
        if (joinColumns.isEmpty) {
          return nonEmptyDataFrames.reduce(_.crossJoin(_))
        }

      /*
      The horrible, gnarly, inelegant code below would ideally exist simply as:

      nonEmptyDataFrames.reduce(_.join(_, joinColumns))

      however that will fail in certain specific circumstances due to a bug in spark,
      see https://issues.apache.org/jira/browse/SPARK-25150 for details
       */
      val aliasSuffix = "_aliased"
      val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn+aliasSuffix)
      var aliasedNonEmptyDataFrames: Seq[DataFrame] = Seq()
      nonEmptyDataFrames.foreach(
        nonEmptyDataFrame =>{
          var tempNonEmptyDataFrame = nonEmptyDataFrame
          joinColumns.foreach(
            joinColumn => {
              tempNonEmptyDataFrame = tempNonEmptyDataFrame.withColumnRenamed(joinColumn, joinColumn+aliasSuffix)
            }
          )
          aliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames :+ tempNonEmptyDataFrame
        }
      )
      var joinedAliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames.reduce(_.join(_, aliasedJoinColumns))
      joinColumns.foreach(
        joinColumn => joinedAliasedNonEmptyDataFrames = joinedAliasedNonEmptyDataFrames.withColumnRenamed(
          joinColumn+aliasSuffix, joinColumn
        )
      )
      joinedAliasedNonEmptyDataFrames
    }
  }
}

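The tests still pass, so I'm fairly happy with it, but I look at those vars, and at the loops that assign a result back to a var on every iteration, and I find them rather inelegant and ugly, especially compared with the original version of the function. I feel there must be a way to write this without vars, but after much trial and error this is the best I could do.
Can anyone suggest a more elegant solution? As a novice Scala developer, it would help me become more familiar with idiomatic ways of solving problems like this.
Any constructive comments on the rest of the code (e.g. the tests) would also be welcome.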

Answer #1, by 3bygqnnd

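Thanks to @duelist, whose suggestion to use foldLeft() led me to the question "How does foldLeft in Scala work on a DataFrame?". That in turn led me to modify the code as follows to eliminate the vars: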

implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
    def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
      val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
      val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
      if (nonEmptyDataFrames.isEmpty){
        emptyDataFrame
      }
      else {
        if (joinColumns.isEmpty) {
          return nonEmptyDataFrames.reduce(_.crossJoin(_))
        }

        /*
        The code below  would ideally exist simply as:

        nonEmptyDataFrames.reduce(_.join(_, joinColumns))

        however that will fail in certain specific circumstances due to a bug in spark,
        see https://issues.apache.org/jira/browse/SPARK-25150 for details

        hence this code aliases the joinColumns, performs the join, then renames the 
        aliased columns back to their original name
         */
        val aliasSuffix = "_aliased"
        val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn+aliasSuffix)
        val joinedAliasedNonEmptyDataFrames = nonEmptyDataFrames.foldLeft(Seq[DataFrame]()){
          (tempDf, nonEmptyDataFrame) => tempDf :+ joinColumns.foldLeft(nonEmptyDataFrame){
            (tempDf2, joinColumn) => tempDf2.withColumnRenamed(joinColumn, joinColumn+aliasSuffix)
          }
        }.reduce(_.join(_, aliasedJoinColumns))
        joinColumns.foldLeft(joinedAliasedNonEmptyDataFrames){
          (tempDf, joinColumn) => tempDf.withColumnRenamed(joinColumn+aliasSuffix, joinColumn)
        }
      }
    }
  }

I could have combined the two statements into one, thereby eliminating val joinedAliasedNonEmptyDataFrames, but I prefer the clarity that the intermediate val brings.
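
For readers new to the pattern, here is a minimal, Spark-free sketch of why foldLeft can replace the var plus foreach loop. Table, FoldLeftSketch, aliasWithVar, aliasWithFold and aliasAll are hypothetical names introduced only for this illustration. Table is a stand-in that tracks column names and mimics the shape of DataFrame.withColumnRenamed; it is not part of Spark. The sketch also suggests that the outer foldLeft, which appends one renamed DataFrame at a time to an initially empty Seq, could probably be written as a plain map.

```scala
// Hypothetical stand-in for a Spark DataFrame: it tracks column names only,
// so the example runs without a Spark session.
final case class Table(columns: Seq[String]) {
  // Mirrors the shape of DataFrame.withColumnRenamed: returns a new value
  // and leaves the original untouched.
  def withColumnRenamed(existing: String, newName: String): Table =
    Table(columns.map(c => if (c == existing) newName else c))
}

object FoldLeftSketch {
  val aliasSuffix = "_aliased"

  // The var + foreach shape from the question: a mutable variable is
  // reassigned on every iteration.
  def aliasWithVar(table: Table, joinColumns: Seq[String]): Table = {
    var temp = table
    joinColumns.foreach { c => temp = temp.withColumnRenamed(c, c + aliasSuffix) }
    temp
  }

  // The same computation with foldLeft: the accumulator `acc` plays the
  // role of the var, starting from `table` and threaded through each step.
  def aliasWithFold(table: Table, joinColumns: Seq[String]): Table =
    joinColumns.foldLeft(table) { (acc, c) => acc.withColumnRenamed(c, c + aliasSuffix) }

  // Building a new Seq by transforming each element independently is what
  // map does, so no accumulator is needed for the outer loop.
  def aliasAll(tables: Seq[Table], joinColumns: Seq[String]): Seq[Table] =
    tables.map(aliasWithFold(_, joinColumns))
}
```

Assuming real DataFrames behave like the stand-in for renaming purposes, the same shape would read nonEmptyDataFrames.map(df => joinColumns.foldLeft(df)(...)).reduce(_.join(_, aliasedJoinColumns)). Whether that is clearer than the nested foldLeft is a matter of taste.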
